/* eslint-disable react/no-unknown-property */
import './KLNiftiViewer.css'

import { CopyFilled, CopyTwoTone } from '@ant-design/icons'
import { Intent } from '@blueprintjs/core'
import { TrackballControls } from '@react-three/drei'
import { Canvas, useFrame, useThree } from '@react-three/fiber'
import { Button, Spin } from 'antd'
import { Colors } from 'common/colors'
import { Spacing } from 'common/stylings'
import Stack from 'components/Stack'
import Toaster from 'components/Toaster'
import Text from 'components/Typography'
import { gunzipSync } from 'fflate'
import * as nifti from 'nifti-reader-js'
import React, { memo, useEffect, useMemo, useRef, useState } from 'react'
import { connect } from 'react-redux'
import { bindActionCreators, Dispatch } from 'redux'
import { RootState } from 'store/rootReducer'
import { editVideoCountFlow } from 'store/videoSegments/actions'
import * as THREE from 'three'

/**
 * Full props of the viewer: the parent-supplied props plus the props injected
 * by `connect` (presumably the combined results of mapStateToProps and
 * mapDispatchToProps via the project's StoreProps helper — TODO confirm).
 */
type Props = KLIndivNiftiViewerProps &
  StoreProps<typeof mapStateToProps, typeof mapDispatchToProps>

/** Props supplied by the parent of the connected viewer. */
interface KLIndivNiftiViewerProps {
  /** URL of the NIfTI file (optionally gzip-compressed); fetched for display and used as the download link. */
  niftiFileUrl: string
  /** Title rendered in bold above the viewer. */
  fileName: string
  // NOTE(review): untyped. The viewer reads `uri` (shown and copied), `_id`,
  // `download_count` and `batch` (sent with the download-count update).
  video: any
}

/** Decoded volume handed to VolumeRenderer. */
interface NiftiDataProps {
  /** Voxel counts along x, y, z (taken from NIfTI dim[1..3]); also the 3D texture size. */
  dims: number[]
  /** Voxel spacing along x, y, z (NIfTI pixdim[1..3]); only their ratios affect the rendered box shape. */
  pixDims: number[]
  /** Intensities normalized to [0, 1] as |v| / max|v|; holds at least dims[0] * dims[1] * dims[2] values. */
  voxelData: Float32Array
}

/**
 * Vertex shader for the volume proxy cube. It forwards the object-space vertex
 * position (the cube spans [-0.5, 0.5] on each axis) to the fragment shader as
 * vPosition, which FRAGMENT_SHADER uses as the point the view ray aims at.
 */
const VERTEX_SHADER = `
  varying vec3 vPosition;
  void main() {
    vPosition = position;
    gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
  }
`

/**
 * Fragment shader: front-to-back ray marching through the 3D intensity texture.
 *
 * For each back-face fragment of the unit cube, a ray is cast from the camera
 * (brought into the cube's object space with invModelMatrix) towards the
 * fragment. The ray is clipped to the cube [-0.5, 0.5]^3 and sampled at most
 * uSteps times without leaving the cube. Samples denser than uMinValue are lit
 * with a central-difference gradient normal and alpha-composited until the ray
 * is nearly opaque; rays that hit nothing are discarded.
 */
const FRAGMENT_SHADER = `
  precision highp float;
  precision highp sampler3D;

  uniform sampler3D uTexture;
  uniform float uSteps;
  uniform float uMinValue;
  uniform vec3 cameraPos;
  uniform mat4 invModelMatrix;
  varying vec3 vPosition;
  const float gradStep = 1.0 / 256.0;
  const int MAX_STEPS = 768;
  // Each iteration advances this many nominal steps (sparser, faster sampling).
  const float STEP_SCALE = 1.2;

  // Slab-method ray / axis-aligned box intersection; returns (tNear, tFar).
  vec2 rayBoxIntersect(vec3 origin, vec3 dir, vec3 boxMin, vec3 boxMax) {
    vec3 invDir = 1.0 / dir;
    vec3 tMin = (boxMin - origin) * invDir;
    vec3 tMax = (boxMax - origin) * invDir;
    vec3 t1 = min(tMin, tMax);
    vec3 t2 = max(tMin, tMax);
    float tNear = max(max(t1.x, t1.y), t1.z);
    float tFar  = min(min(t2.x, t2.y), t2.z);
    return vec2(tNear, tFar);
  }

  // Normalized central-difference gradient of the density field. In flat
  // regions the gradient is zero and normalize() is undefined (NaN on most
  // GPUs), so a fixed +Y normal is used there instead.
  vec3 computeGradient(vec3 samplePos) {
      float dX = texture(uTexture, samplePos + vec3(gradStep, 0, 0)).r -
                 texture(uTexture, samplePos - vec3(gradStep, 0, 0)).r;
      float dY = texture(uTexture, samplePos + vec3(0, gradStep, 0)).r -
                 texture(uTexture, samplePos - vec3(0, gradStep, 0)).r;
      float dZ = texture(uTexture, samplePos + vec3(0, 0, gradStep)).r -
                 texture(uTexture, samplePos - vec3(0, 0, gradStep)).r;
      vec3 grad = vec3(dX, dY, dZ);
      float len = length(grad);
      return len > 1e-6 ? grad / len : vec3(0.0, 1.0, 0.0);
  }

  void main() {
    vec3 camPosObject = (invModelMatrix * vec4(cameraPos, 1.0)).xyz;
    vec3 rayOrigin = camPosObject;
    vec3 rayDir = normalize(vPosition - camPosObject);

    vec2 bounds = rayBoxIntersect(rayOrigin, rayDir, vec3(-0.5), vec3(0.5));
    if (bounds.x > bounds.y) discard;

    // Start at the camera when it is inside the box.
    float tNear = max(bounds.x, 0.0);
    float tFar = bounds.y;
    vec3 start = rayOrigin + tNear * rayDir;
    vec3 end = rayOrigin + tFar * rayDir;
    float rayLength = distance(start, end);
    float stepSize = rayLength / uSteps;
    float marchStep = stepSize * STEP_SCALE;

    vec4 accum = vec4(0.0);
    vec3 currentPos = start;
    float travelled = 0.0;

    for (int i = 0; i < MAX_STEPS; i++) {
      // Stop after uSteps samples or once the ray has left the box: samples
      // past tFar would read clamped edge texels and smear them along the ray.
      if (i >= int(uSteps) || travelled > rayLength) break;
      vec3 samplePos = currentPos + vec3(0.5);
      float density = texture(uTexture, samplePos).r;

      if (density > uMinValue) {
          vec3 normal = computeGradient(samplePos);
          float shadowFactor = clamp(dot(normal, normalize(vec3(-0.5, -0.5, -0.5))), 0.5, 1.0);
          float lightIntensity = clamp(dot(normal, normalize(vec3(1.0, 1.0, 1.0))), 0.6, 1.5);
          lightIntensity *= shadowFactor;
          float opacity = density * 0.9;
          float shade = mix(1.2, 1.8, density);
          vec3 baseColor = vec3(0.6078, 0.6549, 0.5216) * 1.5;
          vec3 color = baseColor * shade * lightIntensity;
          // Front-to-back "over" compositing.
          accum.rgb += (1.0 - accum.a) * color * opacity;
          accum.a += (1.0 - accum.a) * opacity;
          if (accum.a >= 0.98) break;
      }

      currentPos += rayDir * marchStep;
      travelled += marchStep;
    }

    if (accum.a == 0.0) discard;
    gl_FragColor = accum;
  }
`

/**
 * Proxy geometry the volume shader ray-marches through: a unit cube centred on
 * the origin (edge length 1, i.e. [-0.5, 0.5] per axis — BoxGeometry's default
 * size). Built once at module load and shared by every VolumeRenderer.
 */
const sharedBoxGeometry = new THREE.BoxGeometry()

/**
 * Renders one decoded NIfTI volume by GPU ray marching.
 *
 * The normalized intensities are uploaded as a single-channel float 3D texture
 * and drawn on the back faces of a shared unit cube, whose fragment shader
 * marches a ray from the camera through the volume. The cube is scaled by the
 * scan's physical extent (dims × pixDims), normalized so the longest side is 1,
 * which preserves the scan's aspect ratio.
 *
 * GPU resources (texture and shader program) are released when the input data
 * changes and when the component unmounts.
 */
const VolumeRenderer = memo(function VolumeRenderer({
  niftiData,
}: {
  niftiData: NiftiDataProps
}) {
  const { dims, pixDims, voxelData } = niftiData
  const { gl } = useThree()

  const texture = useMemo(() => {
    const tex = new THREE.Data3DTexture(voxelData, dims[0], dims[1], dims[2])
    tex.format = THREE.RedFormat
    tex.type = THREE.FloatType
    // Linear filtering of 32-bit float textures needs OES_texture_float_linear
    // even on WebGL2. Without it the texture is incomplete and samples as 0
    // (a blank render), so fall back to nearest-neighbour filtering.
    const filter = gl.extensions.has('OES_texture_float_linear')
      ? THREE.LinearFilter
      : THREE.NearestFilter
    tex.minFilter = filter
    tex.magFilter = filter
    tex.unpackAlignment = 1
    tex.needsUpdate = true
    tex.anisotropy = gl.capabilities.getMaxAnisotropy()
    return tex
  }, [dims, voxelData, gl])

  const volumeMaterial = useMemo(
    () =>
      new THREE.ShaderMaterial({
        side: THREE.BackSide,
        transparent: true,
        depthWrite: false,
        depthTest: false,
        uniforms: {
          uTexture: { value: texture },
          uSteps: { value: 768.0 },
          uMinValue: { value: 0.5 },
          cameraPos: { value: new THREE.Vector3() },
          invModelMatrix: { value: new THREE.Matrix4() },
        },
        vertexShader: VERTEX_SHADER,
        fragmentShader: FRAGMENT_SHADER,
      }),
    [texture],
  )

  // react-three-fiber does not dispose objects mounted via <primitive>, and the
  // texture is not owned by any scene object, so free their GPU copies
  // explicitly when they are rebuilt for new data and on unmount.
  useEffect(() => () => texture.dispose(), [texture])
  useEffect(() => () => volumeMaterial.dispose(), [volumeMaterial])

  const meshRef = useRef<THREE.Mesh>(null)

  // The shader works in the cube's object space: every frame, hand it the
  // camera position and the inverse of the cube's world matrix.
  useFrame(({ camera }) => {
    const mesh = meshRef.current
    if (!mesh) return
    const { cameraPos, invModelMatrix } = volumeMaterial.uniforms
    cameraPos.value.copy(camera.position)
    invModelMatrix.value.copy(mesh.matrixWorld).invert()
  })

  const scaleVec = useMemo((): [number, number, number] => {
    // Some writers leave pixdim at 0; a zero or non-finite spacing would make
    // the cube degenerate or NaN-scaled, so treat it as 1 (isotropic).
    const spacing = (value: number): number => {
      const magnitude = Math.abs(value)
      return Number.isFinite(magnitude) && magnitude > 0 ? magnitude : 1
    }
    const sizeX = dims[0] * spacing(pixDims[0])
    const sizeY = dims[1] * spacing(pixDims[1])
    const sizeZ = dims[2] * spacing(pixDims[2])
    const maxDim = Math.max(sizeX, sizeY, sizeZ)
    if (!(maxDim > 0)) return [1, 1, 1]
    return [sizeX / maxDim, sizeY / maxDim, sizeZ / maxDim]
  }, [dims, pixDims])

  return (
    <mesh ref={meshRef} scale={scaleVec} geometry={sharedBoxGeometry}>
      <primitive attach="material" object={volumeMaterial} />
    </mesh>
  )
})

/**
 * NIfTI-1 datatype codes (nifti1.h DT_*) the viewer can decode, each with the
 * byte width of one voxel and a DataView reader honouring the file byte order.
 */
const NIFTI_VOXEL_CODECS: Record<
  number,
  {
    bytes: number
    read: (view: DataView, offset: number, littleEndian: boolean) => number
  }
> = {
  2: { bytes: 1, read: (view, offset) => view.getUint8(offset) }, // DT_UINT8
  4: { bytes: 2, read: (view, offset, le) => view.getInt16(offset, le) }, // DT_INT16
  8: { bytes: 4, read: (view, offset, le) => view.getInt32(offset, le) }, // DT_INT32
  16: { bytes: 4, read: (view, offset, le) => view.getFloat32(offset, le) }, // DT_FLOAT32
  64: { bytes: 8, read: (view, offset, le) => view.getFloat64(offset, le) }, // DT_FLOAT64
  256: { bytes: 1, read: (view, offset) => view.getInt8(offset) }, // DT_INT8
  512: { bytes: 2, read: (view, offset, le) => view.getUint16(offset, le) }, // DT_UINT16
  768: { bytes: 4, read: (view, offset, le) => view.getUint32(offset, le) }, // DT_UINT32
}

/**
 * Decodes the first `voxelCount` voxels of a raw NIfTI image into intensities
 * normalized to [0, 1] as |v| / max|v|.
 *
 * @throws Error when the datatype is unsupported or the image is too short.
 */
const decodeNormalizedVoxels = (
  image: ArrayBuffer,
  datatypeCode: number,
  littleEndian: boolean,
  voxelCount: number,
): Float32Array => {
  const codec = NIFTI_VOXEL_CODECS[datatypeCode]
  if (!codec) {
    throw new Error(`Unsupported NIfTI datatype code: ${datatypeCode}.`)
  }
  if (image.byteLength < voxelCount * codec.bytes) {
    throw new Error(
      'Invalid NIfTI file: image data is smaller than its dimensions.',
    )
  }
  const view = new DataView(image)
  // NaN/Infinity samples (possible in float volumes) count as empty space so
  // they cannot poison the maximum or the texture.
  const magnitudeAt = (index: number): number => {
    const value = Math.abs(codec.read(view, index * codec.bytes, littleEndian))
    return Number.isFinite(value) ? value : 0
  }
  let maxVal = 0
  for (let i = 0; i < voxelCount; i++) {
    const value = magnitudeAt(i)
    if (value > maxVal) maxVal = value
  }
  const voxels = new Float32Array(voxelCount)
  // An all-zero volume stays all zeros rather than becoming 0 / 0 = NaN.
  if (maxVal > 0) {
    for (let i = 0; i < voxelCount; i++) {
      voxels[i] = magnitudeAt(i) / maxVal
    }
  }
  return voxels
}

/**
 * Fetches a (optionally gzip-compressed) NIfTI file and decodes its first 3-D
 * volume for VolumeRenderer.
 *
 * @param url Location of the file.
 * @param signal Optional abort signal forwarded to fetch.
 * @throws Error on HTTP failure, invalid/unsupported files or aborts.
 */
const loadNiftiFile = async (
  url: string,
  signal?: AbortSignal,
): Promise<NiftiDataProps> => {
  const response = await fetch(url, { signal })
  if (!response.ok) {
    throw new Error(`Failed to fetch: ${response.statusText}`)
  }
  let buffer = await response.arrayBuffer()
  if (nifti.isCompressed(buffer)) {
    buffer = gunzipSync(new Uint8Array(buffer)).buffer
  }
  if (!nifti.isNIFTI(buffer)) {
    throw new Error('Invalid NIfTI file.')
  }
  const header = nifti.readHeader(buffer)
  if (!header) {
    throw new Error('Invalid NIfTI file: Could not read header.')
  }
  const image = nifti.readImage(header, buffer)
  // dims[0] is the rank; dims[1..3] are the x/y/z extents. Extents beyond the
  // rank (e.g. z of a 2-D image) may be stored as 0, so clamp them to 1.
  const rank = header.dims[0]
  const dims = [1, 2, 3].map((axis) =>
    axis <= rank && header.dims[axis] > 0 ? header.dims[axis] : 1,
  )
  const pixDims = [header.pixDims[1], header.pixDims[2], header.pixDims[3]]
  // Only the first 3-D volume is displayed, so only it is decoded and used
  // for normalization (later volumes of a 4-D series are ignored).
  const voxelData = decodeNormalizedVoxels(
    image,
    header.datatypeCode,
    header.littleEndian,
    dims[0] * dims[1] * dims[2],
  )
  return { dims, pixDims, voxelData }
}

/**
 * Shows a single NIfTI file: title, copyable URI, a download button that
 * bumps the video's download count once, and an interactive 3D volume view.
 */
const KLIndivNiftiViewer: React.FC<Props> = ({
  niftiFileUrl,
  fileName,
  video,
  editVideoCount,
}) => {
  const [niftiData, setNiftiData] = useState<NiftiDataProps | null>(null)
  const [errorMessage, setErrorMessage] = useState<string | null>(null)
  const [clickCopy, setClickCopy] = useState(false)
  const [isDownloaded, setIsDownloaded] = useState(false)
  const [isLoading, setIsLoading] = useState(true)

  useEffect(() => {
    if (!niftiFileUrl || niftiFileUrl.trim() === '') return

    // Forget what the previous URL produced so a stale volume or error is
    // never shown while the new file loads.
    setNiftiData(null)
    setErrorMessage(null)
    setIsLoading(true)

    // Aborting cancels the download and suppresses late state updates.
    const controller = new AbortController()

    const loadNiftiData = async () => {
      try {
        const loadedData = await loadNiftiFile(niftiFileUrl, controller.signal)
        if (!controller.signal.aborted) {
          setNiftiData(loadedData)
        }
      } catch (error) {
        if (!controller.signal.aborted) {
          setErrorMessage(
            error instanceof Error && error.message
              ? error.message
              : 'Failed to load NIfTI file.',
          )
        }
      } finally {
        if (!controller.signal.aborted) {
          setIsLoading(false)
        }
      }
    }

    void loadNiftiData()

    return () => {
      controller.abort()
    }
  }, [niftiFileUrl])

  const renderCopyButton = () => {
    const toggleAfterCopy = async () => {
      try {
        // Throws when the Clipboard API is unavailable (non-secure origin)
        // and rejects when the browser denies clipboard access.
        await navigator.clipboard.writeText(video.uri)
      } catch {
        Toaster.show({
          icon: 'error',
          intent: Intent.DANGER,
          message: 'Could not copy to clipboard.',
        })
        return
      }
      setClickCopy(true)
      Toaster.show({
        icon: 'tick',
        intent: Intent.SUCCESS,
        message: 'Copied to Clipboard!',
      })
    }
    return !clickCopy ? (
      <CopyTwoTone
        style={{ marginLeft: '5px' }}
        onClick={() => void toggleAfterCopy()}
      />
    ) : (
      <CopyFilled style={{ marginLeft: '5px' }} />
    )
  }

  // Counts at most one download per mounted viewer.
  const handleDownload = () => {
    if (!isDownloaded) {
      editVideoCount({
        videoId: video._id,
        download_count: video.download_count + 1,
        batch: video.batch,
      })
      setIsDownloaded(true)
    }
  }

  const renderDownloadButton = () => (
    <Stack justifyContent="flex-start">
      <a
        href={niftiFileUrl}
        download
        style={{ textDecoration: 'none' }}
        onClick={handleDownload}>
        <Button
          disabled={isDownloaded}
          type="primary"
          style={{ marginTop: '0.5rem', marginBottom: '0.75rem' }}>
          Download 3D Asset
        </Button>
      </a>
    </Stack>
  )

  const renderUserInstructions = () => (
    <div
      style={{
        marginBottom: '0.75rem',
        fontSize: '1rem',
        color: 'black',
      }}>
      Press and hold &apos;D&apos; and the left mouse button, then drag to
      explore the scene. Drag your mouse to rotate the view and use the scroll
      wheel to zoom in and out.
    </div>
  )

  return (
    <Stack vertical gutter={Spacing.MEDIUM}>
      <Stack vertical>
        <Text
          style={{ fontSize: '18px', paddingTop: '12px' }}
          fontSize={16}
          fontWeight="bold">
          {fileName}
        </Text>
        <Text color={Colors.PALE_GREY} fontSize={12}>
          {video.uri}
          {renderCopyButton()}
          {renderDownloadButton()}
        </Text>
        {renderUserInstructions()}
        <div
          style={{
            width: '100%',
            height: '100%',
            display: 'flex',
            justifyContent: 'center',
            alignItems: 'center',
          }}>
          {isLoading ? (
            <Spin tip="Loading NIFTI file..." />
          ) : errorMessage ? (
            <p style={{ color: '#fff' }}>{errorMessage}</p>
          ) : niftiData ? (
            <Canvas
              className="custom-canvas"
              dpr={[1, 1.5]}
              gl={{ antialias: true }}
              camera={{ position: [0.9, 0.9, 0.9], fov: 45 }}
              style={{ flex: 1, background: '#000' }}>
              <ambientLight intensity={1.0} />
              <VolumeRenderer niftiData={niftiData} />
              <TrackballControls rotateSpeed={2} />
            </Canvas>
          ) : null}
        </div>
      </Stack>
    </Stack>
  )
}

/** The viewer consumes nothing from the Redux store. */
function mapStateToProps(_state: RootState) {
  return {}
}
/** Binds the action creator the viewer uses to report a download. */
const mapDispatchToProps = (dispatch: Dispatch) => {
  const actionCreators = { editVideoCount: editVideoCountFlow }
  return bindActionCreators(actionCreators, dispatch)
}

// Store-connected viewer: parent props are merged with the injected store props.
const connector = connect(mapStateToProps, mapDispatchToProps)
export default connector(KLIndivNiftiViewer)
