import { useFrame, useThree } from '@react-three/fiber'
import PropTypes from 'prop-types'
import {
  forwardRef,
  memo,
  useEffect,
  useMemo,
  useRef
} from 'react'
import {
  ClampToEdgeWrapping,
  FloatType,
  Mesh,
  NearestFilter,
  PlaneGeometry,
  RGBAFormat,
  ShaderMaterial,
  Vector2,
  WebGLRenderTarget
} from 'three'
import useFallbackRef from '@/lib/react/hooks/useFallbackRef'
import fadeFrag from './fade.frag'
import fadeVert from './fade.vert'
import paintFrag from './paint.frag'
import paintVert from './paint.vert'
import useGPUPicker from '@/gl/hooks/useGPUPicker'

const OFFSCREEN = { x: -1, y: -1 }

/**
 * The `PointerPainter`
 * @param {object} props - the component props
 * @returns {React.ReactElement} the element
 */
const PointerPainter = forwardRef((props, forwardedRef) => {
  const {
    brushRadius = 0.1,
    children,
    fadeSpeed = 0.5,
    fbo: fbo2 = new WebGLRenderTarget(1, 1, {
      minFilter: NearestFilter,
      magFilter: NearestFilter,
      format: RGBAFormat,
      type: FloatType,
      wrapS: ClampToEdgeWrapping,
      wrapT: ClampToEdgeWrapping
    }),
    isDragging = false,
    isHover = false,
    pointer = { state: { x: 0, y: 0 } },
    size = new Vector2(1, 1)
  } = props
  const { camera, gl, size: viewportSize, viewport } = useThree()
  const ref = useFallbackRef(forwardedRef)
  const picker = useGPUPicker()

  // FBOs
  const fbo1 = useRef(
    new WebGLRenderTarget(1, 1, {
      minFilter: NearestFilter,
      magFilter: NearestFilter,
      format: RGBAFormat,
      type: FloatType,
      wrapS: ClampToEdgeWrapping,
      wrapT: ClampToEdgeWrapping
    })
  )

  useEffect(() => {
    const pixelRatio = gl.getPixelRatio()
    const width = viewportSize.width * pixelRatio
    const height = viewportSize.height * pixelRatio

    // Dynamically update FBO sizes
    fbo1.current.setSize(width, height)
    fbo2.current.setSize(width, height)
  }, [viewportSize, gl])

  const paintMaterial = useMemo(
    () =>
      new ShaderMaterial({
        uniforms: {
          uTexture: { value: null },
          uPointer: { value: new Vector2() },
          uFadeSpeed: { value: fadeSpeed },
          uPress: { value: 0 },
          uBrushRadius: { value: brushRadius },
          uSize: { value: size }
        },
        vertexShader: paintVert,
        fragmentShader: paintFrag
      }),
    [brushRadius, fadeSpeed, size]
  )

  const fadeMaterial = useMemo(
    () =>
      new ShaderMaterial({
        uniforms: {
          uTexture: { value: null },
          uFadeSpeed: { value: fadeSpeed }
        },
        vertexShader: fadeVert,
        fragmentShader: fadeFrag
      }),
    [fadeSpeed]
  )

  const fboPlane = useMemo(() => {
    const geometry = new PlaneGeometry(1, 1) // Fullscreen quad in NDC

    return new Mesh(geometry, new ShaderMaterial())
  }, [])

  useFrame(() => {
    if (!ref.current) return

    if (!isDragging) {
      let match

      if (isHover) {
        match = picker.pick(pointer.state, ref.current)

        if (match) {
          pointer.state.mapX = match.x
          pointer.state.mapY = match.y
        }
      }

      if (!match) {
        match = OFFSCREEN
      }

      paintMaterial.uniforms.uPointer.value.set(match.x, match.y)
    }

    fboPlane.scale.set(viewport.width, viewport.height, 1)

    // Render painting pass
    fboPlane.material = paintMaterial
    gl.setRenderTarget(fbo1.current)
    paintMaterial.uniforms.uTexture.value = fbo2.current.texture
    gl.render(fboPlane, camera)

    // Render fading pass
    fboPlane.material = fadeMaterial
    gl.setRenderTarget(fbo2.current)
    fadeMaterial.uniforms.uTexture.value = fbo1.current.texture
    gl.render(fboPlane, camera)

    // Reset render target
    gl.setRenderTarget(null)
  })

  useEffect(() => {
    if (isDragging) {
      let match = picker.pick(pointer.state, ref.current)

      if (!match) {
        match = OFFSCREEN
      } else {
        pointer.state.mapX = match.x
        pointer.state.mapY = match.y
      }

      paintMaterial.uniforms.uPointer.value.set(match.x, match.y)
    }

    paintMaterial.uniforms.uPress.value = isDragging ? 1 : 0
  }, [isDragging])

  return <mesh ref={ref}>{children}</mesh>
})

// Name shown in React DevTools and in React warnings for this component.
PointerPainter.displayName = 'PointerPainter'

// Runtime prop validation (development builds only).
// `PropTypes.object` is used where the original called `PropTypes.shape()`
// with no shape; both accept any non-array object.
PointerPainter.propTypes = {
  brushRadius: PropTypes.number,
  children: PropTypes.node,
  // NOTE(review): displacementMap, displacementScale, onEnter, onLeave,
  // onPress and onRelease are declared but not read by PointerPainter
  // itself. Presumably they are consumed by a wrapper or kept for API
  // compatibility — confirm before removing.
  displacementMap: PropTypes.object,
  displacementScale: PropTypes.number,
  fadeSpeed: PropTypes.number,
  // Caller-supplied history buffer for the fade pass; the accepted forms are
  // defined by PointerPainter's use of this prop.
  fbo: PropTypes.object,
  isDragging: PropTypes.bool,
  isHover: PropTypes.bool,
  onEnter: PropTypes.func,
  onLeave: PropTypes.func,
  onPress: PropTypes.func,
  onRelease: PropTypes.func,
  // Pointer holder; its `state` is passed to the GPU picker.
  pointer: PropTypes.object,
  // Uploaded as the `uSize` uniform of the paint pass.
  size: PropTypes.object
}

// memo: skip re-rendering when props are shallowly equal.
export default memo(PointerPainter)
