package ssh;

import ssh.RSA.*;
import Tools.*;
import jp.*;
import proof.*;
import sdsi.*;

import java.io.*;
import java.net.*;
import java.util.*;
import java.math.*;

/**
 * This class implements a simple encrypted channel, based on
 * the ssh protocol. This class actually implements both halves; a
 * SSHServerSocket is just a thing that accept()s requests and creates
 * one of these SSHSockets in server mode to handle the server side of
 * a connection.<p>
 *
 * Several features are left out (man-in-the-middle and privacy
 * defenses); this is okay for my purposes, since I implement and
 * verify those services in a higher layer based on my
 * restricted-delegation logic.<p>
 * 
 * Source: based on ClientProtocol.java, my Java implementation of ssh
 * 1.5.
 * 
 * @author Jon Howell <jonh@cs.dartmouth.edu>
 * 
 * @license This code is Copyright 1999 Jon Howell. It is available for
 * use under the GNU Public License, available at:
 * http://www.gnu.org/copyleft/gpl.html
 * 
 * @rcs $Id: SSHSocket.java,v 1.5 2000/05/22 01:35:59 jonh Exp $
 */

public class SSHSocket
	extends Socket
	implements KeyedSocket, SRPConstants {

	/**
	 * Version identification line sent in the clear at the start of
	 * every connection. It is modeled on the ssh protocol string:
	 * 'SRP' names the protocol (originally "Secure RMI Protocol", since
	 * it was mainly used under RMI), 1.0 is the protocol version, and
	 * 1.0.0 is the client software version (of limited use if there are
	 * multiple implementations).
	 */
	public static String protocolVersion = "SRP-1.0-1.0.0\n";

	/**
	 * Longest remote version line (including its terminator) that we are
	 * willing to buffer. The peer is unauthenticated at this point, and an
	 * unbounded read would let it exhaust our memory by never sending a
	 * newline. SSH-2 (RFC 4253 section 4.2) uses the same 255-character cap.
	 */
	static final int MAX_VERSION_LINE_LENGTH = 255;

	/**
	 * Creates an unconnected socket for the server half of the protocol.
	 * SSHServerSocket uses this when it accepts an incoming connection,
	 * and later calls initServer() to run the handshake. Clients use the
	 * other constructors to make outbound connections.
	 *
	 * @param context our public/private keys and random-number generator
	 */
	SSHSocket(SSHContext context) {
		this.context = context;
		sshIn = null;
		sshOut = null;
		oppositeKey = null;
		serverSocket = true;
	}

	/**
	 * Connects to a remote server and runs the client end of the
	 * protocol ("client-end" just means we run the client's half of the
	 * handshake). If the handshake fails, the underlying TCP socket is
	 * closed before the exception propagates.
	 *
	 * @param context our public/private keys and random-number generator
	 * @param remoteAddress server address
	 * @param remotePort server port
	 * @param localAddress local address to bind, or null for any
	 * @param localPort local port to bind, or 0 for any
	 * @throws IOException if the connection or the handshake fails
	 */
	public SSHSocket(SSHContext context,
		InetAddress remoteAddress, int remotePort,
		InetAddress localAddress, int localPort)
		throws IOException {
		super(remoteAddress, remotePort, localAddress, localPort);
		this.context = context;
		initClientOrClose();
	}

	/**
	 * Connects to remoteHost:remotePort and runs the client end of the
	 * protocol. If the handshake fails, the underlying TCP socket is
	 * closed before the exception propagates.
	 *
	 * @throws IOException if the connection or the handshake fails
	 */
	public SSHSocket(SSHContext context,
		String remoteHost, int remotePort)
		throws IOException {
		super(remoteHost, remotePort);
		this.context = context;
		initClientOrClose();
	}

	/**
	 * Connects to remoteAddress:remotePort from any local address/port
	 * and runs the client end of the protocol.
	 *
	 * @throws IOException if the connection or the handshake fails
	 */
	public SSHSocket(SSHContext context,
		InetAddress remoteAddress, int remotePort)
		throws IOException {
		this(context, remoteAddress, remotePort, null, 0);
	}

	/**
	 * Returns the public key that identifies the other end of this
	 * connection. It is null until the key exchange has completed.
	 */
	public RSAKey getOppositeKey() {
		return oppositeKey;
	}

	/**
	 * Runs initClient(). If it fails, closes the already-connected TCP
	 * socket before rethrowing: a throwing constructor leaves the caller
	 * no reference with which to close it.
	 */
	private void initClientOrClose()
		throws IOException {
		boolean initialized = false;
		try {
			initClient();
			initialized = true;
		} finally {
			if (!initialized) {
				closeUnderlyingQuietly();
			}
		}
	}

	/**
	 * Closes only the underlying TCP socket and ignores any failure. It
	 * is used while another exception is already propagating.
	 */
	private void closeUnderlyingQuietly() {
		try {
			super.close();
		} catch (IOException ignored) {
			// Best effort: the handshake failure is the error we report.
		}
	}

	/**
	 * Runs the client half of the protocol over the connected socket.
	 * It exchanges version lines, then performs the key exchange, then
	 * builds the encrypted streams that getInputStream() and
	 * getOutputStream() hand out.
	 *
	 * @throws IOException on I/O failure or protocol violation
	 */
	public void initClient()
		throws IOException {
		socketIn = super.getInputStream();
		socketOut = super.getOutputStream();

		// log.setPrefix("initClient: ").log("verbose", "exchangeVersion");
		exchangeVersionIdentification();
		binaryIn = new BinaryPacketInputStream(socketIn);
		binaryOut = new BinaryPacketOutputStream(socketOut);
		// log.log("verbose", "key exchange");
		clientKeyExchange();	// prove we each have our respective keys;
								// establish session key

		// these are the streams a higher-level protocol will use to
		// actually transmit data.
		sshIn = new SshInputStream(binaryIn);
		sshOut = new SshOutputStream(binaryOut);
		// log.log("verbose", "socket is connected to "+getInetAddress()+", port "+getPort());

		// log.log("keys", "SSH key at my end: " +Prover2.staticGetName(new SDSIRSAPublicKey(context.publicKey)));
		// log.log("keys", "SSH key at other end: " +Prover2.staticGetName(new SDSIRSAPublicKey(oppositeKey)));
	}

	/**
	 * Runs the server half of the protocol. SSHServerSocket calls this
	 * on an instance created with SSHSocket(SSHContext) once the
	 * underlying connection exists. It exchanges version lines, performs
	 * the key exchange, and builds the encrypted streams.
	 *
	 * @throws IOException on I/O failure or protocol violation
	 */
	void initServer()
		throws IOException {
		socketIn = super.getInputStream();
		socketOut = super.getOutputStream();

		log.setPrefix("initServer: ").log("verbose", "exchangeVersion");
		exchangeVersionIdentification();
		binaryIn = new BinaryPacketInputStream(socketIn);
		binaryOut = new BinaryPacketOutputStream(socketOut);
		// log.log("verbose", "key exchange");
		serverKeyExchange();	// prove we each have our respective keys;
								// establish session key

		// these are the streams a higher-level protocol will use to
		// actually transmit data.
		sshIn = new SshInputStream(binaryIn);
		sshOut = new SshOutputStream(binaryOut);
		// log.log("verbose", "done");

		// log.log("keys", "SSH key at my end: " +Prover2.staticGetName(new SDSIRSAPublicKey(context.publicKey)));
		// log.log("keys", "SSH key at other end: " +Prover2.staticGetName(new SDSIRSAPublicKey(oppositeKey)));
	}

	/**
	 * Returns the decrypting input stream. As a side effect, it records
	 * the remote end's key as the caller key for the current thread (see
	 * whoCalledMe()). Returns null until the handshake has completed.
	 */
	public InputStream getInputStream() {
		callersKey.set(getOppositeKey());
		return sshIn;
	}

	/**
	 * Returns the encrypting output stream, or null until the handshake
	 * has completed.
	 */
	public OutputStream getOutputStream() {
		return sshOut;
	}

	/**
	 * Closes the encrypted streams, so that buffered data is flushed
	 * cleanly, and then the underlying socket. This is safe to call on a
	 * server-side socket whose handshake never completed; its encrypted
	 * streams are still null. The TCP socket is closed even if closing
	 * one of the streams fails.
	 */
	public void close()
		throws IOException {
		try {
			try {
				if (sshOut != null) {
					sshOut.close();
				}
			} finally {
				if (sshIn != null) {
					sshIn.close();
				}
			}
		} finally {
			super.close();
		}
	}

	////////////////////////////////////////////////////////////////////////
	//  fields  ////////////////////////////////////////////////////////////
	////////////////////////////////////////////////////////////////////////
	/**
	 * The public/private keys for our end of the connection, plus a
	 * random-number generator.
	 */
	SSHContext context;

	/**
	 * Underlying stream (data is encrypted at this level)
	 */
	InputStream socketIn;
	OutputStream socketOut;

	/**
	 * Streams that package data into packets to be encrypted
	 */
	BinaryPacketInputStream binaryIn;
	BinaryPacketOutputStream binaryOut;

	/**
	 * Caller-visible streams that look like a normal stream (but get
	 * encrypted underneath on the socketIn/socketOut streams).
	 */
	SshInputStream sshIn;
	SshOutputStream sshOut;

	/**
	 * key received for remote end of connection
	 */
	RSAKey oppositeKey;

	/** True when this instance runs the server half of the protocol. */
	boolean serverSocket = false;

	/**
	 * Maps a thread to an RSAKey. Recipients of RMI calls use it to
	 * determine the public key associated with the channel over which
	 * the RMI call arrived.
	 *
	 * The socket knows the public key of the remote end, since it did
	 * the key checking. It establishes this mapping in getInputStream().
	 * The RMI subsystem calls that when opening the socket, before (and
	 * in the same thread as) dispatching the RMI call to the remote
	 * object implementation. The remote object can then retrieve the
	 * mapping, knowing only its thread, to learn which public key was
	 * associated with the Socket that the call came in on.
	 *
	 * There is a RISK that, if some untrusted code runs in this JVM,
	 * it could fiddle with this mapping to lie about who did what
	 * "saying." That needs further thought.
	 */
	static PerThread callersKey = new PerThread();

	static Log log = new Log();
	static {
		// log.addLevel("main");
		// log.addLevel("verbose");
		// log.addLevel("packets");
		// log.addLevel("borrow");
		log.addLevel("keys");
			// info about the borrowed-session-key shortcut
	}

	/** Whether clients may try to reuse a cached session key. */
	static boolean borrowingAllowed = true;
	public static void setBorrowingAllowed(boolean state) {
		borrowingAllowed = state;
	}

	/**
	 * If you are a remote object implementation, you may call this to
	 * learn the RSAKey public key identity that authenticated the
	 * calling end of this socket. In speaks-for terms, the principal
	 * returned by whoCalledMe() "says" remoteMethod(arguments...).
	 *
	 * If you call this from another method, be aware of which thread
	 * you are in. The call works by matching the current thread with
	 * the thread that "answered" the incoming Socket connection. If you
	 * might be on the other side of a queue (in a different Thread) from
	 * the original RMI call, this may return null or, worse, a
	 * meaningless key.
	 */
	public static RSAKey whoCalledMe() {
		return (RSAKey) callersKey.get();
	}

	/**
	 * Maps the session-key hash (as an Integer) to the SessionInfo
	 * (remote key and session key) negotiated under it.
	 *
	 * Shared by every socket in the JVM, which may be created from
	 * different threads. Every access to this map AND to
	 * destinationToKey must synchronize on sessionInfoCache.
	 */
	static HashMap sessionInfoCache = new HashMap();

	/**
	 * Maps a (server,port) tuple to a session-key hash, which we use to
	 * request the use of that session key from the server. Guarded by
	 * the sessionInfoCache lock.
	 */
	static HashMap destinationToKey = new HashMap();

	////////////////////////////////////////////////////////////////////////
	//  private methods  ///////////////////////////////////////////////////
	////////////////////////////////////////////////////////////////////////

	/**
	 * Sends our version line, then reads and validates the peer's. The
	 * peer must send "SRP-<major>.<minor>-<software>" with major == 1.
	 * Reading stops at the first \r or \n, because the binary packet
	 * protocol starts immediately after it.
	 *
	 * @throws EOFException if the peer closes before sending a line
	 * @throws IOException if the line is too long, malformed, or has an
	 *	incompatible major version
	 */
	void exchangeVersionIdentification()
		throws IOException {

		// send our version number to the other end (ASCII by definition)
		socketOut.write(protocolVersion.getBytes("US-ASCII"));

		// read characters until a \r or \n. Don't read past that!
		// after that, the protocol switches to ssh's binary packet protocol
		StringBuffer sb = new StringBuffer();
		while (true) {
			int ci = socketIn.read();
			if (ci<0) throw new EOFException();
			char c = (char) ci;
			if (c == '\r') {
				c = '\n';
			}
			sb.append(c);
			if (c == '\n') {
				break;
			}
			if (sb.length() >= MAX_VERSION_LINE_LENGTH) {
				throw new IOException("Remote version string too long");
			}
		}

		String remoteVersion = sb.toString();
		try {
			StringTokenizer tok = new StringTokenizer(remoteVersion, "-");
			String constantSRP = tok.nextToken();
			String remoteProtocol = tok.nextToken();
			tok.nextToken();	// software version: must be present; not checked
			if (!constantSRP.equals("SRP")) {
				throw new IOException("Remote version string is bad: "
					+remoteVersion);
			}

			StringTokenizer tok2 = new StringTokenizer(remoteProtocol, ".");
			int protMajor = Integer.parseInt(tok2.nextToken());
			Integer.parseInt(tok2.nextToken());	// minor: must be numeric; unused
			if (protMajor != 1) {
				throw new IOException("Remote major version incompatible: "
					+remoteProtocol);
			}
		} catch (NoSuchElementException ex) {
			// Too few '-' or '.' separated fields.
			throw new IOException("Remote version string is bad: "
				+remoteVersion);
		} catch (NumberFormatException ex) {
			// Non-numeric protocol version component.
			throw new IOException("Remote version string is bad: "
				+remoteVersion);
		}
	}

	/**
	 * Runs the client half of the key exchange over binaryIn/binaryOut.
	 *
	 * If borrowing is allowed and a session key negotiated earlier with
	 * the same (address, port) is cached, it first asks the server to
	 * reuse that key. Otherwise, or if the server declines, it performs
	 * the full public-key exchange: answer the server's challenge with
	 * our private key, challenge the server in turn, and receive the
	 * session key encrypted to our public key. On success the ciphers
	 * are installed, oppositeKey is set, and the session is cached for
	 * later borrowing.
	 *
	 * @throws IOException on I/O failure or any protocol violation,
	 *	including a server that fails our challenge
	 */
	void clientKeyExchange()
		throws IOException {
		BinaryPacketOut bpo;
		SessionInfo sessionInfo = null;

		// See whether we might like to take the borrowed-key shortcut
		Integer skhi = null;
		if (borrowingAllowed) {
			Object dest = destinationKey();
			synchronized (sessionInfoCache) {
				skhi = (Integer) destinationToKey.get(dest);
				if (skhi!=null) {
					sessionInfo = (SessionInfo) sessionInfoCache.get(skhi);
				}
			}
		}
		if (sessionInfo!=null) {
			log.log("borrow", "trying shortcut");
			bpo = binaryOut.newPacket();
			bpo.setType(SRP_CMSG_BORROW_SESSION_KEY);
			bpo.writeInt(skhi.intValue());
			binaryOut.writePacket(bpo);
		} else {
			// Send the server a packet saying we want to do a
			// full key exchange
			bpo = binaryOut.newPacket();
			bpo.setType(SRP_CMSG_KEY_EXCHANGE);
			binaryOut.writePacket(bpo);
		}

		// The server answers SUCCESS if it accepted the borrowed key;
		// otherwise its first packet is its public key.
		BinaryPacketIn bpi = binaryIn.readPacket();
		if (sessionInfo!=null
			&& bpi.getType() == SRP_SMSG_SUCCESS) {
			installSessionInfo(sessionInfo);
			sendCMSG_SUCCESS();
			log.log("borrow", "borrow-key shortcut successful.");
			return;
		}

		expectPacketType(bpi, SRP_SMSG_SERVER_KEY);
		RSAKey serverKey				= RSAKey.readSsh(bpi);
		bpi.readInt();		// protocolFlags: read to advance; unused
		bpi.readInt();		// supportedCiphersMask: unused; we always pick IDEA
		BigInteger clientChallenge		= StreamExtras.readBigInteger(bpi);

		// take the challenge (prove we hold the private side of
		// the client key we're sending; prevents replay attack)
		BigInteger eClientChallenge
			= context.privateKey.encrypt(clientChallenge.toByteArray(),
				context.random);

		// create our own challenge for the server
		BigInteger serverChallenge = context.random.newBigInteger(32);

		// send SRP_CMSG_CLIENT_KEY packet.
		bpo = binaryOut.newPacket();
		bpo.setType(SRP_CMSG_CLIENT_KEY);
		context.publicKey.writeSsh(bpo);
		StreamExtras.writeBigInteger(bpo, eClientChallenge);
		StreamExtras.writeBigInteger(bpo, serverChallenge);
		bpo.writeByte(SRP_CIPHER_IDEA);
		binaryOut.writePacket(bpo);

		// receive SRP_SMSG_SESSION_KEY packet
		bpi = binaryIn.readPacket();
		expectPacketType(bpi, SRP_SMSG_SESSION_KEY);
		BigInteger eServerChallenge		= StreamExtras.readBigInteger(bpi);
		BigInteger eSessionKey			= StreamExtras.readBigInteger(bpi);
		int sessionKeyHash				= bpi.readInt();

		// verify server challenge was met
		byte[] dec = serverKey.decrypt(eServerChallenge);
		if (dec == null) {
			throw new IOException(
				"server's response does not decrypt to challenge");
		}
		BigInteger serverChallengeReply = new BigInteger(dec);
		if (!serverChallengeReply.equals(serverChallenge)) {
			throw new IOException(
				"server's response does not decrypt to challenge");
		}

		// decrypt session key (for my eyes only!)
		byte[] sessionKeyPlain = context.privateKey.decrypt(eSessionKey);
		if (sessionKeyPlain == null || sessionKeyPlain.length == 0) {
			throw new IOException("session key does not decrypt");
		}
		// NOTE(review): the BigInteger round-trip can drop redundant
		// leading sign bytes (0x00/0xFF), so this key could differ in
		// length from the server's raw 32 bytes. Kept for wire
		// compatibility; confirm against RSAKey.decrypt's output format.
		byte[] sessionKey = new BigInteger(sessionKeyPlain).toByteArray();

		sessionInfo = new SessionInfo();
		sessionInfo.oppositeKey = serverKey;
		sessionInfo.sessionKey = sessionKey;
		installSessionInfo(sessionInfo);

		sendCMSG_SUCCESS();

		// Cache the session key for reuse in sessions with the same
		// endpoints to avoid extra key exchanges. (see comments & caveats
		// in SRPProtocol.java)
		skhi = new Integer(sessionKeyHash);
		Object dest = destinationKey();
		synchronized (sessionInfoCache) {
			sessionInfoCache.put(skhi, sessionInfo);
			destinationToKey.put(dest, skhi);
		}
	}

	/**
	 * Sends SRP_CMSG_SUCCESS. It goes out under the freshly installed
	 * cipher, so it demonstrates that we hold the session key.
	 */
	void sendCMSG_SUCCESS()
		throws IOException {
		BinaryPacketOut bpo = binaryOut.newPacket();
		bpo.setType(SRP_CMSG_SUCCESS);
		binaryOut.writePacket(bpo);
	}

	/**
	 * Builds the destinationToKey lookup key for this socket's remote
	 * (address, port).
	 */
	Object destinationKey() {
		HashKey hk = new HashKey();
		hk.add(getInetAddress());
		hk.add(new Integer(getPort()));
		return hk;
	}

	/**
	 * Runs the server half of the key exchange.
	 *
	 * If the client asks to borrow a session key that is in the cache,
	 * the server agrees and only waits for the client's encrypted
	 * SUCCESS. Otherwise it performs the full exchange: challenge the
	 * client, verify its response, answer the client's challenge, and
	 * send a fresh random session key encrypted to the client's public
	 * key. The new session is cached for later borrowing.
	 *
	 * @throws IOException on I/O failure or any protocol violation,
	 *	including a client that fails our challenge
	 */
	void serverKeyExchange()
		throws IOException {

		// relevant secrets about session: client public key; secret
		// session key
		SessionInfo sessionInfo;
		BinaryPacketIn bpi;
		BinaryPacketOut bpo;

		// Wait for the client to possibly suggest the use of a borrowed
		// session key from another session.
		bpi = binaryIn.readPacket();
		if (bpi.getType() == SRP_CMSG_BORROW_SESSION_KEY) {
			// If borrowed key is acceptable, send SRP_SMSG_SUCCESS
			// to indicate the negotiation is complete.
			Integer skhi = new Integer(bpi.readInt());
			synchronized (sessionInfoCache) {
				sessionInfo = (SessionInfo) sessionInfoCache.get(skhi);
			}
			if (sessionInfo==null) {
				// log.log("borrow", "Client wants to borrow session key " +"with hash "+skhi+" but I don't have it.");
			} else {
				// log.log("borrow", "session key available; going along " +"with shortcut.");
				// Tell the client everything is copacetic
				bpo = binaryOut.newPacket();
				bpo.setType(SRP_SMSG_SUCCESS);
				binaryOut.writePacket(bpo);
				// Install the session key and require an encrypted
				// CMSG_SUCCESS. The hash travelled in the clear, so this
				// check is what proves the client really holds the key.
				installSessionInfo(sessionInfo);
				receiveCMSG_SUCCESS();
				// exchange done! Whew, that was fast.
				return;
			}
		} else if (bpi.getType() == SRP_CMSG_KEY_EXCHANGE) {
			// Else fall through and negotiate a session key using
			// public keys, the normal way.
		} else {
			throw new IOException("SSH Protocol error: unexpected packet type "
				+bpi.getType()+"; expected "+SRP_CMSG_BORROW_SESSION_KEY
				+" or "+SRP_CMSG_KEY_EXCHANGE+".");
		}

		// create a challenge for the client
		BigInteger clientChallenge = context.random.newBigInteger(32);

		// the first packet from the server is the server key
		bpo = binaryOut.newPacket();
		bpo.setType(SRP_SMSG_SERVER_KEY);
		context.publicKey.writeSsh(bpo);
		bpo.writeInt(0);						// protocolFlags
		bpo.writeInt(SRP_CIPHER_IDEA);			// supportedCipherMask
		StreamExtras.writeBigInteger(bpo, clientChallenge);
		binaryOut.writePacket(bpo);

		// receive client's response and challenge for us
		bpi = binaryIn.readPacket();
		expectPacketType(bpi, SRP_CMSG_CLIENT_KEY);
		RSAKey clientKey				= RSAKey.readSsh(bpi);
		BigInteger eClientChallenge		= StreamExtras.readBigInteger(bpi);
		BigInteger serverChallenge		= StreamExtras.readBigInteger(bpi);
		byte selectedCipher				= bpi.readByte();

		// verify clientChallenge was responded to successfully
		byte[] dec = clientKey.decrypt(eClientChallenge);
		if (dec == null) {
			throw new IOException(
				"client's response does not decrypt to challenge");
		}
		BigInteger clientChallengeReply = new BigInteger(dec);
		if (!clientChallengeReply.equals(clientChallenge)) {
			throw new IOException(
				"client's response does not decrypt to challenge");
		}

		// ensure cipher selection is okay
		Assert.assert(selectedCipher == SRP_CIPHER_IDEA);
		if (selectedCipher != SRP_CIPHER_IDEA) {
			throw new IOException("Unsupported cipher selected by client: "
				+selectedCipher);
		}

		// take the server challenge (prove we hold the private side of
		// the server key we're sending; prevents replay attack)
		BigInteger eServerChallenge
			= context.privateKey.encrypt(serverChallenge.toByteArray(),
				context.random);

		// create session key
		byte[] sessionKey
			= context.random.newByteArray(32);	// 32 bytes of randomness
		BigInteger eSessionKey = clientKey.encrypt(sessionKey, context.random);
		MD5 md5 = new MD5();
		md5.update(sessionKey);
		int sessionKeyHash = md5.hashCode();

		// send session key packet.
		bpo = binaryOut.newPacket();
		bpo.setType(SRP_SMSG_SESSION_KEY);
		StreamExtras.writeBigInteger(bpo, eServerChallenge);
		StreamExtras.writeBigInteger(bpo, eSessionKey);
		bpo.writeInt(sessionKeyHash);
		binaryOut.writePacket(bpo);

		sessionInfo = new SessionInfo();
		sessionInfo.oppositeKey = clientKey;
		sessionInfo.sessionKey = sessionKey;
		installSessionInfo(sessionInfo);

		receiveCMSG_SUCCESS();

		// Cache the session key for reuse in sessions with the same
		// endpoints to avoid extra key exchanges. (see comments & caveats
		// in SRPProtocol.java)
		synchronized (sessionInfoCache) {
			sessionInfoCache.put(new Integer(sessionKeyHash), sessionInfo);
		}
	}

	/**
	 * Reads the client's SRP_CMSG_SUCCESS packet, which arrives under the
	 * installed session cipher, and fails the handshake if any other
	 * packet type arrives.
	 *
	 * @throws IOException on I/O failure or an unexpected packet type
	 */
	void receiveCMSG_SUCCESS()
		throws IOException {
		BinaryPacketIn bpi = binaryIn.readPacket();
		expectPacketType(bpi, SRP_CMSG_SUCCESS);
	}

	/**
	 * Fails the handshake unless bpi has the expected type.
	 *
	 * Keeps the file's original Assert.assert call. The explicit
	 * IOException means the protocol check does not depend on however
	 * Assert is configured (NOTE(review): its behavior is defined
	 * elsewhere).
	 *
	 * @param bpi packet just read from the peer
	 * @param expected the packet type the protocol requires here
	 * @throws IOException if the type does not match
	 */
	private static void expectPacketType(BinaryPacketIn bpi, int expected)
		throws IOException {
		Assert.assert(bpi.getType() == expected);
		if (bpi.getType() != expected) {
			throw new IOException("Unexpected packet type: "+bpi.getType()
				+"; expected "+expected);
		}
	}

	/**
	 * Installs IDEA ciphers keyed with the session key on both packet
	 * streams, and records the remote end's public key.
	 */
	void installSessionInfo(SessionInfo sessionInfo) {
		// need separate cipher objects on each stream, even though
		// the key is the same, because streams are chained independently.
		// (draft-ylonen page 5). [Makes sense, since packets are sent
		// asynchronously from each end.]
		Cipher cipher = new CipherIdea();
		cipher.setKey(sessionInfo.sessionKey);
		binaryOut.setCipher(cipher);

		cipher = new CipherIdea();
		cipher.setKey(sessionInfo.sessionKey);
		binaryIn.setCipher(cipher);

		oppositeKey = sessionInfo.oppositeKey;
	}

	/**
	 * The negotiated secrets of one session: the remote end's public key
	 * and the shared session key.
	 *
	 * The class is static because instances live in the static
	 * sessionInfoCache. A non-static inner class would keep its
	 * originating SSHSocket (and that socket's streams) reachable forever.
	 */
	static class SessionInfo {
		RSAKey oppositeKey;
		byte[] sessionKey;
	}

	/**
	 * Small manual test. With -server it listens and writes one line to
	 * each connection. Otherwise it connects 'count' times and copies
	 * whatever the server sends to stdout.
	 */
	public static void main(String[] args) {
		try {
			Options opts = new Options(args) {
				public void defineOptions() {
		programName = "SSHSocket";
		defineOption("server", "Act as server", "false");
		defineOption("host", "Host to connect to/listen on", "localhost");
		defineOption("port", "Port number to connect to/listen on", "9876");
		defineOption("once", "Server exits after first connection", "true");
		defineOption("count", "Number of times client connects", "1");
				}
			};

			SSHContext context = SSHContext.getDefault();

			if (opts.getBoolean("server")) {
				// server side
				SSHServerSocket servSock =
					new SSHServerSocket(context, opts.getInt("port"));
				while (true) {
					try {
						// log.log("main", "listening for connection");
						SSHSocket connSock = (SSHSocket) servSock.accept();
						// log.log("main", "received connection");
						try {
							Writer w =
								new OutputStreamWriter(connSock.getOutputStream());
							w.write("This connection is ALIIIIVE!\n");
							w.flush();
						} finally {
							// close even if the write failed
							connSock.close();
						}
						if (opts.getBoolean("once")) {
							break;
						}
					} catch (Exception ex) {
						ex.printStackTrace();
						System.err.println("continuing...");
					}
				}
				System.exit(0);
			} else {
				// client side
				for (int i=0; i<opts.getInt("count"); i++) {
					// log.log("main", "making connection");
					SSHSocket clientSock = new SSHSocket(context,
						opts.get("host"), opts.getInt("port"));
					try {
						// log.log("main", "connected; listening for message.");
						InputStream is = clientSock.getInputStream();
						byte[] buf = new byte[256];
						while (true) {
							int rc = is.read(buf);
							if (rc<=0) {
								// log.log("main", "socket closed");
								break;
							}
							System.out.write(buf, 0, rc);
						}
					} finally {
						// release each connection before making the next
						clientSock.close();
					}
				}
			}
		} catch (Exception ex) {
			ex.printStackTrace();
		}
	}
}
