import * as THREE from 'three';
import { STLExporter, STLExporterOptions, STLExporterOptionsBinary, STLExporterOptionsString } from 'three/examples/jsm/exporters/STLExporter.js';

/** Size in bytes of the fixed binary-STL header. */
const STL_HEADER_BYTES = 80;
/** Size of one binary-STL facet: normal + 3 vertices (12 float32) + 2-byte attribute word. */
const STL_FACET_BYTES = 50;

/**
 * Map a colour channel to the 0-31 range used by 15-bit STL facet colours.
 * Out-of-range inputs (e.g. values above 1) are clamped so they cannot spill
 * into a neighbouring channel's bits or the "colour valid" flag.
 */
function quantizeChannel(value: number): number {
    return Math.floor(Math.min(Math.max(value, 0), 1) * 31);
}

/**
 * Pack a material's colour into a binary-STL facet attribute word
 * (VisCAM/SolidView layout): bit 15 = colour present, bits 10-14 red,
 * bits 5-9 green, bits 0-4 blue.
 *
 * Materials without a recognised `color` property, material arrays and
 * `undefined` all produce white (0xffff).
 *
 * NOTE(review): channels are written exactly as stored on the material;
 * depending on three's colour-management settings these may be linear rather
 * than sRGB values — confirm what the consuming slicer expects.
 */
function packFacetColor(material: THREE.Material | THREE.Material[] | undefined): number {
    const color =
        material instanceof THREE.MeshBasicMaterial ||
        material instanceof THREE.MeshLambertMaterial ||
        material instanceof THREE.MeshPhongMaterial ||
        material instanceof THREE.MeshStandardMaterial || // also covers MeshPhysicalMaterial
        material instanceof THREE.MeshToonMaterial ||
        material instanceof THREE.MeshMatcapMaterial
            ? material.color
            : undefined;

    if (!color) return 0xffff;
    return (
        0x8000 |
        (quantizeChannel(color.r) << 10) |
        (quantizeChannel(color.g) << 5) |
        quantizeChannel(color.b)
    );
}

/**
 * `STLExporter` variant that always writes binary STL and stores a colour per
 * facet in the 16-bit "attribute byte count" field (see `packFacetColor`).
 *
 * Colours come from the mesh material. For multi-material meshes, each facet
 * takes the material of the geometry group that contains it; facets outside
 * every group are white.
 */
export class ColoredSTLExporter extends STLExporter {
    parse(scene: THREE.Object3D, options: STLExporterOptionsBinary): DataView;
    parse(scene: THREE.Object3D, options?: STLExporterOptionsString): string;
    /**
     * Serialize every mesh under `scene` into a binary STL file.
     *
     * Vertices are transformed by each mesh's current `matrixWorld` (it is read,
     * not recomputed). Facet normals are recomputed from the transformed
     * vertices; any normal attribute on the geometry is not used.
     *
     * @param scene - Root of the subtree to export; every `THREE.Mesh` with a
     *   `position` attribute is written.
     * @param options - Accepted for signature compatibility with `STLExporter`
     *   and otherwise ignored (not read and not modified): output is always binary.
     * @returns A `DataView` over the complete file. NOTE: a `DataView` is returned
     *   even when the string (ASCII) overload is selected, because per-facet
     *   colour only exists in binary STL.
     */
    parse(scene: THREE.Object3D, options: STLExporterOptions = {}): string | DataView {
        const meshes: THREE.Mesh[] = [];
        let triangleCount = 0;

        // Pass 1: collect exportable meshes and size the output exactly, so the
        // file is written once instead of through a temporary buffer per facet.
        scene.traverse((object) => {
            if (!(object instanceof THREE.Mesh)) return;
            const geometry = object.geometry;
            if (!(geometry instanceof THREE.BufferGeometry)) return;
            const position = geometry.getAttribute('position');
            if (!position) return; // no vertices to export
            const index = geometry.getIndex();
            const elementCount = index ? index.count : position.count;
            // A trailing incomplete triangle is skipped instead of read out of range.
            triangleCount += Math.floor(elementCount / 3);
            meshes.push(object);
        });

        const buffer = new ArrayBuffer(STL_HEADER_BYTES + 4 + STL_FACET_BYTES * triangleCount);
        const view = new DataView(buffer);

        // Header (unused bytes stay zero), then the little-endian facet count.
        new TextEncoder().encodeInto(
            'Binary STL File with Colors - Generated by SheetBuild',
            new Uint8Array(buffer, 0, STL_HEADER_BYTES)
        );
        view.setUint32(STL_HEADER_BYTES, triangleCount, true);

        // Scratch vectors reused for every facet.
        const a = new THREE.Vector3();
        const b = new THREE.Vector3();
        const c = new THREE.Vector3();
        const edge1 = new THREE.Vector3();
        const edge2 = new THREE.Vector3();

        let offset = STL_HEADER_BYTES + 4;
        const writeVector = (v: THREE.Vector3): void => {
            view.setFloat32(offset, v.x, true);
            view.setFloat32(offset + 4, v.y, true);
            view.setFloat32(offset + 8, v.z, true);
            offset += 12;
        };

        // Pass 2: write every facet directly into the final buffer.
        for (const mesh of meshes) {
            const geometry = mesh.geometry;
            const position = geometry.getAttribute('position');
            const index = geometry.getIndex();
            const matrixWorld = mesh.matrixWorld;
            const elementCount = index ? index.count : position.count;

            const material = mesh.material;
            const groups = geometry.groups;
            const meshColor = packFacetColor(material); // white for material arrays
            // Multi-material meshes select a material per facet through geometry
            // groups: group.materialIndex indexes into the material array.
            const groupColors = Array.isArray(material)
                ? groups.map((group) => packFacetColor(material[group.materialIndex ?? 0]))
                : [];
            const facetColor = (start: number): number => {
                if (groupColors.length === 0) return meshColor;
                const g = groups.findIndex(
                    (group) => start >= group.start && start < group.start + group.count
                );
                return g === -1 ? meshColor : (groupColors[g] ?? meshColor);
            };

            // Same loop bound as the pass-1 count, so writes stay within the buffer.
            for (let i = 0; i + 2 < elementCount; i += 3) {
                const i1 = index ? index.getX(i) : i;
                const i2 = index ? index.getX(i + 1) : i + 1;
                const i3 = index ? index.getX(i + 2) : i + 2;

                // World-space vertices.
                a.fromBufferAttribute(position, i1).applyMatrix4(matrixWorld);
                b.fromBufferAttribute(position, i2).applyMatrix4(matrixWorld);
                c.fromBufferAttribute(position, i3).applyMatrix4(matrixWorld);

                // Facet normal = normalize((b - a) x (c - a)).
                edge1.subVectors(b, a);
                edge2.subVectors(c, a);
                edge1.cross(edge2).normalize();

                writeVector(edge1);
                writeVector(a);
                writeVector(b);
                writeVector(c);

                view.setUint16(offset, facetColor(i), true);
                offset += 2;
            }
        }

        return view;
    }
}