// Copyright 1997-1998 Omni Development, Inc.  All rights reserved.
//
// This software may only be used and reproduced according to the
// terms in the file OmniSourceLicense.html, which should be
// distributed with this project and can also be found at
// http://www.omnigroup.com/DeveloperResources/OmniSourceLicense.html.

#import "ONTCPSocket.h"

#import <Foundation/Foundation.h>
#import <OmniBase/OmniBase.h>
#import <OmniBase/system.h>

#import "ONHost.h"
#import "ONHostAddress.h"
#import "ONInternetSocket-Private.h"
#import "ONServiceEntry.h"

RCS_ID("$Header: /Network/Developer/Source/CVS/OmniGroup/OmniNetworking/ONTCPSocket.m,v 1.15 1998/12/08 04:08:44 kc Exp $")

// Helper declared here in the implementation file; it is implemented in the
// (Private) category at the end of this file.
@interface ONTCPSocket (Private)
// Accepts a pending connection on socketFD and returns the new descriptor.
// Raises ONTCPSocketAcceptFailedExceptionName if accept() fails.
- (int)socketFDForAcceptedConnection;
@end

@implementation ONTCPSocket

// Class whose +socket method is used by +tcpSocket to create new sockets.
static Class defaultTCPSocketClass = nil;

// Installs ONTCPSocket itself as the default socket class, exactly once.
+ (void)initialize;
{
    static BOOL hasInitialized = NO;

    [super initialize];

    // +initialize can be entered more than once (for instance on behalf of
    // subclasses), so only the first pass may set the default; later passes
    // must not clobber a class installed via +setDefaultTCPSocketClass:.
    if (!hasInitialized) {
        hasInitialized = YES;
        defaultTCPSocketClass = [ONTCPSocket class];
    }
}

// Returns the class currently used by +tcpSocket to create sockets.
// This is ONTCPSocket itself unless +setDefaultTCPSocketClass: has replaced it.
+ (Class)defaultTCPSocketClass;
{
    return defaultTCPSocketClass;
}

// Replaces the class that +tcpSocket uses to create sockets.
// aClass must be ONTCPSocket or one of its subclasses; this is checked with
// OBASSERT, so a wrong (or Nil) class is reported here, where it is set,
// rather than surfacing later when +tcpSocket is used.
+ (void)setDefaultTCPSocketClass:(Class)aClass;
{
    OBASSERT(aClass != Nil && [aClass isSubclassOfClass:[ONTCPSocket class]]);
    defaultTCPSocketClass = aClass;
}

// Returns a new socket obtained by sending +socket to the current default
// TCP socket class. Note that the default class is used, not the receiver,
// so sending this to a subclass still yields an instance of the default class.
+ (ONTCPSocket *)tcpSocket;
{
    Class socketClass = defaultTCPSocketClass;

    return (ONTCPSocket *)[socketClass socket];
}

// Listening for and accepting connections

// Starts listening without requesting a specific port: forwards to
// -startListeningOnLocalPort: with port number 0.
// NOTE(review): presumably 0 lets the system pick a free port -- confirm
// against -setLocalPortNumber:.
- (void)startListeningOnAnyLocalPort;
{
    const unsigned short int unspecifiedPort = 0;

    [self startListeningOnLocalPort:unspecifiedPort];
}

// Sets the local port number to `port` and puts the socket into the listening
// state with a backlog of 5.
// Raises ONTCPSocketListenFailedExceptionName if listen() fails; the
// listening flag is only set once listen() has succeeded.
- (void)startListeningOnLocalPort:(unsigned short int)port;
{
    int listenResult;

    [self setLocalPortNumber:port];

    listenResult = listen(socketFD, 5);
    if (listenResult == -1)
        [NSException raise:ONTCPSocketListenFailedExceptionName format:@"Unable to listen on socket: %s", strerror(OMNI_ERRNO())];

    flags.listening = YES;
}

// Starts listening on the port number reported by `service`.
- (void)startListeningOnLocalService:(ONServiceEntry *)service;
{
    unsigned short int servicePort = [service portNumber];

    [self startListeningOnLocalPort:servicePort];
}

// Accepts one pending connection and makes this object represent it: the
// listening descriptor is closed and replaced by the accepted descriptor,
// and the flags change from listening to connected.
// If accepting raises, nothing has been modified yet.
- (void)acceptConnection;
{
    int listeningFD, acceptedFD;
    BOOL closedListeningSocket;

    // Accept first so the listening descriptor is only given up once a
    // replacement is in hand.
    acceptedFD = [self socketFDForAcceptedConnection];

    listeningFD = socketFD;
    closedListeningSocket = (OBSocketClose(listeningFD) == 0);
    OBASSERT(closedListeningSocket);

    socketFD = acceptedFD;
    flags.connected = YES;
    flags.listening = NO;
}

// Accepts one pending connection and returns it wrapped in a new, autoreleased
// socket of the receiver's class; the receiver keeps its own descriptor.
// Raises ONTCPSocketAcceptFailedExceptionName if the accept fails.
- (ONTCPSocket *)acceptConnectionOnNewSocket;
{
    int acceptedSocketFD;

    // Accept before allocating. With both in a single expression the order of
    // evaluation is unspecified, and if the instance were allocated first an
    // exception raised by the accept would leak it.
    acceptedSocketFD = [self socketFDForAcceptedConnection];

    // Use -class rather than reading the isa ivar directly.
    return [[[[self class] alloc] _initWithSocketFD:acceptedSocketFD connected:YES] autorelease];
}

// ONInternetSocket subclass

// Socket type for this class: SOCK_STREAM.
// NOTE(review): presumably consulted by ONInternetSocket when it creates the
// underlying descriptor -- confirm in the superclass.
+ (int)socketType;
{
    return SOCK_STREAM;
}

// Protocol number for this class: IPPROTO_TCP.
// NOTE(review): presumably consulted by ONInternetSocket when it creates the
// underlying descriptor -- confirm in the superclass.
+ (int)protocol;
{
    return IPPROTO_TCP;
}

// Reads up to byteCount bytes from the socket into aBuffer and returns the
// count reported by OBSocketRead().
// A socket that is listening but not yet connected first accepts a connection
// (replacing its listening descriptor, see -acceptConnection).
// Raises ONInternetSocketNotConnectedExceptionName if the socket is neither
// connected nor listening, and ONInternetSocketReadFailedExceptionName if the
// read reports -1.
- (unsigned int)readBytes:(unsigned int)byteCount intoBuffer:(void *)aBuffer;
{
    int readResult;

    // Loop rather than test once: the connected flag is re-checked after
    // every accept attempt.
    while (!flags.connected) {
        if (flags.listening) {
            [self acceptConnection];
            continue;
        }
        [NSException raise:ONInternetSocketNotConnectedExceptionName format:@"Attempted read from a non-connected socket"];
    }

    readResult = OBSocketRead(socketFD, aBuffer, byteCount);
    if (readResult == -1)
        [NSException raise:ONInternetSocketReadFailedExceptionName format:@"Unable to read from socket: %s", strerror(OMNI_ERRNO())];

    return (unsigned int)readResult;
}

// Writes up to byteCount bytes from aBuffer to the socket in a single
// OBSocketWrite() call and returns the count that call reports. When
// MAX_BYTES_PER_WRITE is defined, at most that many bytes are offered to the
// call, so callers must be prepared for a short count.
// A socket that is listening but not yet connected first accepts a connection
// (replacing its listening descriptor, see -acceptConnection).
// Raises ONInternetSocketNotConnectedExceptionName if the socket is neither
// connected nor listening, and ONInternetSocketWriteFailedExceptionName if the
// write reports -1.
- (unsigned int)writeBytes:(unsigned int)byteCount fromBuffer:(const void *)aBuffer;
{
    int writeResult;

    // Loop rather than test once: the connected flag is re-checked after
    // every accept attempt.
    while (!flags.connected) {
        if (flags.listening) {
            [self acceptConnection];
            continue;
        }
        [NSException raise:ONInternetSocketNotConnectedExceptionName
         format:@"Attempted write to a non-connected socket"];
    }

#ifdef MAX_BYTES_PER_WRITE
    // Cap the size of a single write on platforms that define a limit.
    writeResult = OBSocketWrite(socketFD, aBuffer, byteCount > MAX_BYTES_PER_WRITE ? MAX_BYTES_PER_WRITE : byteCount);
#else
    writeResult = OBSocketWrite(socketFD, aBuffer, byteCount);
#endif
    if (writeResult == -1)
        [NSException raise:ONInternetSocketWriteFailedExceptionName format:@"Unable to write to socket: %s", strerror(OMNI_ERRNO())];

    return (unsigned int)writeResult;
}

@end

@implementation ONTCPSocket (Private)

// Accepts one pending connection on socketFD and returns the new descriptor.
// accept() is retried for as long as it fails with EINTR. Any other failure
// raises ONTCPSocketAcceptFailedExceptionName.
// Side effect: the peer address reported by accept() is copied into the
// receiver's remoteAddress, which is allocated on first use.
- (int)socketFDForAcceptedConnection;
{
    int newSocketFD;
    struct sockaddr_in acceptAddress;
    socklen_t acceptAddressLength; // accept() takes a socklen_t *, not an int *

    do {
        // The length is a value-result argument, so reset it before every
        // attempt instead of relying on a failed call leaving it untouched.
        acceptAddressLength = sizeof(acceptAddress);
        newSocketFD = accept(socketFD, (struct sockaddr *)&acceptAddress, &acceptAddressLength);
    } while (newSocketFD == -1 && OMNI_ERRNO() == EINTR);

    if (newSocketFD == -1)
        [NSException raise:ONTCPSocketAcceptFailedExceptionName format:@"Socket accept failed: %s", strerror(OMNI_ERRNO())];

    if (!remoteAddress)
        remoteAddress = NSZoneMalloc(NULL, sizeof(struct sockaddr_in));
    *remoteAddress = acceptAddress;
    return newSocketFD;
}

@end


// Exception names raised by -startListeningOnLocalPort: (listen failure) and
// -socketFDForAcceptedConnection (accept failure).
// NOTE(review): DEFINE_NSSTRING is not defined in this file (presumably it
// comes from OmniBase) -- confirm what string value it gives each constant.
DEFINE_NSSTRING(ONTCPSocketListenFailedExceptionName);
DEFINE_NSSTRING(ONTCPSocketAcceptFailedExceptionName);
