/*
 * Created on 24-jul-2008
 */
package be.SIRAPRISE.messages;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.security.InvalidKeyException;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;

import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;

import be.SIRAPRISE.client.CommunicationProtocolException;
import be.SIRAPRISE.client.ErrorMessageException;
import be.SIRAPRISE.client.Signer;
import be.SIRAPRISE.security.ProprietaryOrJCECipher;
import be.SIRAPRISE.util.MyDataInputStream;
import be.erwinsmout.NotFoundException;

/**
 * The ServerMessage class is an abstract class defining all possible messages sent to/by the SIRA_PRISE server
 * 
 * @author Erwin Smout
 */
public abstract class ServerMessage {

	/**
	 * The method used by readers to obtain a (structured) message from the given stream.
	 * 
	 * @param in
	 *            The inputstream from which the ServerMessage to be returned is to be constructed
	 * @param cryptoProtocol
	 *            The protocol to be used for decryption of what arrives after the byte count
	 * @param signingProtocol
	 *            the Signature used to verify the message signature
	 * @param publicKey
	 *            the public key used to verify the signature
	 * @return The ServerMessage read from the given stream
	 * @throws IOException
	 * @throws ErrorMessageException
	 * @throws CommunicationProtocolException
	 */
	public static ServerMessage readMessage (DataInputStream in, ProprietaryOrJCECipher cryptoProtocol, Signature signingProtocol, PublicKey publicKey) throws IOException, ErrorMessageException, CommunicationProtocolException {
		int byteCount = 0;
		// boolean noResponse = true;
		// do {
		// try {
		byteCount = in.readInt();
		if (byteCount < 0) {
			throw new IOException(Messages.getString("ServerMessage.NegativeByteCount")); //$NON-NLS-1$
		}
		// noResponse = false;
		// } catch (EOFException e1) {
		//				
		// }
		// } while (noResponse);

		byte[] inputBytes = MyDataInputStream.readExactNumberOfBytes(in, byteCount);

		if (cryptoProtocol != null) {
			try {
				inputBytes = cryptoProtocol.decrypt(inputBytes);
			} catch (InvalidKeyException e1) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessage.DecryptionFailed"), e1); //$NON-NLS-1$
			} catch (IllegalBlockSizeException e1) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessage.DecryptionFailed"), e1); //$NON-NLS-1$
			} catch (BadPaddingException e1) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessage.DecryptionFailed"), e1); //$NON-NLS-1$
			}
		}
		DataInputStream localInputStream = new DataInputStream(new ByteArrayInputStream(inputBytes));

		if (signingProtocol != null) {
			try {
				short signatureLength = localInputStream.readShort();
				byte[] signature = new byte[signatureLength];
				localInputStream.read(signature);

				int actualMessageLength = localInputStream.readInt();
				byte[] actualMessage = MyDataInputStream.readExactNumberOfBytes(localInputStream, actualMessageLength);
				localInputStream = new DataInputStream(new ByteArrayInputStream(actualMessage));

				signingProtocol.initVerify(publicKey);
				signingProtocol.update(actualMessage);
				if (!signingProtocol.verify(signature)) {
					throw new CommunicationProtocolException(Messages.getString("ServerMessage.VerificationFailed"), (Exception) null); //$NON-NLS-1$
				}
			} catch (InvalidKeyException e1) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessage.VerificationFailed"), e1); //$NON-NLS-1$
			} catch (SignatureException e) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessage.VerificationFailed"), e); //$NON-NLS-1$
			}
		}

		// always get the message type and version first
		int messageType = localInputStream.readInt();
		int fullMessageVersion = localInputStream.readInt();

		ServerMessageType serverMessageType;
		try {
			serverMessageType = ServerMessageTypes.getInstance().getServerMessageTypeForMessageVersion(messageType, fullMessageVersion);
		} catch (NotFoundException e) {
			throw new CommunicationProtocolException(messageType, fullMessageVersion);
		}

		ServerMessage serverMessage = serverMessageType.typeSpecificFromStream(localInputStream);

		// System.out.println("Message Type read is "+serverMessageType.getClass().getName());

		if (serverMessageType instanceof ErrorMessageType) {
			throw new ErrorMessageException((ServerErrorMessage) serverMessage);
		}
		return serverMessage;
	}

	/**
	 * The MessageType object representing the type (and version of type) this new Message is of
	 */
	private ServerMessageType type;

	/**
	 * Creates a ServerMessage, setting the type
	 * 
	 * @param type
	 *            The MessageType object representing the type (and version of type) this new Message is of
	 */
	ServerMessage (ServerMessageType type) {
		this.type = type;
	}

	/**
	 * Gets the message type identification number
	 * 
	 * @return the message type identification number
	 */
	public final int getMessageTypeID ( ) {
		return type.getMessageTypeID();
	}

	/**
	 * Sends this message to the given output stream using the given signing and encryption settings.
	 * 
	 * @param out
	 *            The outputStream to which the message is to be written
	 * @param signingProtocol
	 *            The Signature object needed to compute the signature for the message, or null if no signing is required.
	 * @param signer
	 *            The Signer object that will compute the message signature using the given signingProtocol
	 * @param cryptoProtocol
	 *            The encryption object that will computed the encrypted message if encryption is required, null if no encryption is required.
	 * @throws CommunicationProtocolException
	 *             <ul>
	 *             <li>If any resource needed for computing a signature for this message (such as e.g. a private key), could not be found or turns out to be invalid</li>
	 *             <li>If any problem occurred during computation of the signature</li>
	 *             </ul>
	 * @throws IOException
	 */
	public final void sendMessage (DataOutputStream out, Signature signingProtocol, Signer signer, ProprietaryOrJCECipher cryptoProtocol) throws IOException, CommunicationProtocolException {

		// System.out.println("Sending Message Type " + getClass().getName());

		ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(256);
		DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream);
		dataOutputStream.writeInt(type.getMessageTypeID());
		dataOutputStream.writeShort(type.getMajorVersion());
		dataOutputStream.writeShort(type.getMinorVersion());
		type.typeSpecificToStream(this, dataOutputStream);

		// If signing required, sign
		ByteArrayOutputStream signedByteArrayOutputStream;
		if (signingProtocol != null) {
			signedByteArrayOutputStream = new ByteArrayOutputStream(byteArrayOutputStream.size());
			try {
				// Write the signature first (=length + actual signature), only then the actual message
				byte[] signMessage = byteArrayOutputStream.toByteArray();
				byte[] signature = signer.sign(signingProtocol, signMessage);
				DataOutputStream signedDataOutputStream = new DataOutputStream(signedByteArrayOutputStream);
				signedDataOutputStream.writeShort(signature.length);
				signedDataOutputStream.write(signature);
				// Now the actual message (=length + message)
				signedDataOutputStream.writeInt(byteArrayOutputStream.size());
				signedDataOutputStream.write(signMessage);
			} catch (InvalidKeyException e) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessageType.SigningFailed"), e); //$NON-NLS-1$
			} catch (SignatureException e) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessageType.SigningFailed"), e); //$NON-NLS-1$
			} catch (NotFoundException e) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessageType.SigningFailed"), e); //$NON-NLS-1$
			}
		} else {
			signedByteArrayOutputStream = byteArrayOutputStream;
		}

		// If encryption required, encrypt
		ByteArrayOutputStream byteArrayOutputStreamToBeWritten;
		if (cryptoProtocol != null) {
			byteArrayOutputStreamToBeWritten = new ByteArrayOutputStream();
			try {
				byteArrayOutputStreamToBeWritten.write(cryptoProtocol.encrypt(signedByteArrayOutputStream.toByteArray()));
			} catch (InvalidKeyException e) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessageType.EncryptionFailed"), e); //$NON-NLS-1$
			} catch (IllegalBlockSizeException e) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessageType.EncryptionFailed"), e); //$NON-NLS-1$
			} catch (BadPaddingException e) {
				throw new CommunicationProtocolException(Messages.getString("ServerMessageType.EncryptionFailed"), e); //$NON-NLS-1$
			}
		} else {
			byteArrayOutputStreamToBeWritten = signedByteArrayOutputStream;
		}

//		System.out.println(byteArrayOutputStreamToBeWritten.toString() + "/" + MyByteBuffer.toHex(ByteBuffer.wrap(byteArrayOutputStreamToBeWritten.toByteArray())));

		out.writeInt(byteArrayOutputStreamToBeWritten.size());
		byteArrayOutputStreamToBeWritten.writeTo(out);
		out.flush();
	}
}