diff --git a/include/net/class.h b/include/net/class.h index 5e0ab59..3ca0c4d 100644 --- a/include/net/class.h +++ b/include/net/class.h @@ -2,6 +2,7 @@ #include "sys/types.h" #include "sys/list.h" +struct io_notify; struct sockaddr; struct socket; @@ -18,6 +19,7 @@ struct sockops { int (*accept) (struct socket *, struct socket *); int (*count_pending) (struct socket *); + struct io_notify *(*get_rx_notify) (struct socket *); }; struct socket_class { diff --git a/include/net/socket.h b/include/net/socket.h index 97df79a..0c9e403 100644 --- a/include/net/socket.h +++ b/include/net/socket.h @@ -10,7 +10,7 @@ struct netdev; struct socket { struct sockops *op; struct vfs_ioctx *ioctx; - struct io_notify rx_notify; + //struct io_notify rx_notify; void *data; }; @@ -38,3 +38,4 @@ int net_setsockopt(struct vfs_ioctx *ioctx, struct ofile *fd, int optname, void void net_close(struct vfs_ioctx *ioctx, struct ofile *fd); int socket_has_data(struct socket *sock); +struct io_notify *socket_get_rx_notify(struct socket *sock); diff --git a/net/socket.c b/net/socket.c index 26dc44f..1f03b98 100644 --- a/net/socket.c +++ b/net/socket.c @@ -18,6 +18,11 @@ int socket_has_data(struct socket *sock) { return !!sock->op->count_pending(sock); } +struct io_notify *socket_get_rx_notify(struct socket *sock) { + _assert(sock->op && sock->op->get_rx_notify); + return sock->op->get_rx_notify(sock); +} + int net_open(struct vfs_ioctx *ioctx, struct ofile *fd, int dom, int type, int proto) { struct socket_class *cls, *iter; @@ -44,7 +49,6 @@ int net_open(struct vfs_ioctx *ioctx, struct ofile *fd, int dom, int type, int p fd->flags = OF_SOCKET; fd->socket.ioctx = ioctx; fd->socket.op = cls->ops; - thread_wait_io_init(&fd->socket.rx_notify); return cls->ops->open(&fd->socket); } diff --git a/net/unix.c b/net/unix.c index 05ebc89..64b184f 100644 --- a/net/unix.c +++ b/net/unix.c @@ -32,6 +32,8 @@ static ssize_t unix_socket_sendto(struct socket *s, static ssize_t unix_socket_recvfrom(struct socket *s, void *buf, size_t lim, struct sockaddr *dst, size_t *salen); +static int unix_socket_count_pending(struct socket *s); +static struct io_notify *unix_socket_get_rx_notify(struct socket *s); static struct sockops unix_socket_ops = { .open = unix_socket_open, @@ -44,6 +46,9 @@ static struct sockops unix_socket_ops = { .connect = unix_socket_connect, .accept = unix_socket_accept, + + .count_pending = unix_socket_count_pending, + .get_rx_notify = unix_socket_get_rx_notify, }; static struct socket_class unix_socket_class = { .name = "unix", @@ -60,11 +65,16 @@ static ssize_t unix_conn_recvfrom(struct socket *s, void *buf, size_t lim, struct sockaddr *dst, size_t *salen); static void unix_conn_close(struct socket *s); +static int unix_conn_count_pending(struct socket *sock); +static struct io_notify *unix_conn_get_rx_notify(struct socket *sock); static struct sockops unix_conn_ops = { .sendto = unix_conn_sendto, .recvfrom = unix_conn_recvfrom, - .close = unix_conn_close + .close = unix_conn_close, + + .count_pending = unix_conn_count_pending, + .get_rx_notify = unix_conn_get_rx_notify, }; //// @@ -130,6 +140,7 @@ static void unix_socket_close(struct socket *sock) { _assert(vn); vn->fs_data = NULL; } else { + kdebug("Closing non-server socket\n"); // Hangup connection if (data->remote) { struct unix_conn *conn = data->remote; @@ -155,6 +166,28 @@ static void unix_socket_close(struct socket *sock) { kfree(data); } +static int unix_socket_count_pending(struct socket *sock) { + struct unix_socket *data = sock->data; + _assert(data); + + if (data->type == 1) { + return !!data->remote; + } else { + panic("TODO\n"); + } +} + +static struct io_notify *unix_socket_get_rx_notify(struct socket *sock) { + struct unix_socket *data = sock->data; + _assert(data); + + if (data->type == 1) { + return &data->client_notify; + } else { + panic("TODO\n"); + } +} + static void unix_conn_close(struct socket *s) { struct unix_conn *conn = s->data; _assert(conn); @@ -166,7 +199,7 @@ static void unix_conn_close(struct socket *s) { conn->state = STATE_TERMINATED; // Send EOF to remote _assert(conn->client); - ring_signal(&conn->server_tx, RING_SIGNAL_EOF); + ring_signal(&conn->server_tx, RING_SIGNAL_EOF | RING_SIGNAL_RET); conn->server = NULL; } @@ -180,6 +213,24 @@ static void unix_conn_close(struct socket *s) { } } +static int unix_conn_count_pending(struct socket *sock) { + struct unix_conn *conn = sock->data; + _assert(conn); + + if (conn->state != STATE_ESTABLISHED) { + return 1; + } + + return ring_readable(&conn->client_tx); +} + +static struct io_notify *unix_conn_get_rx_notify(struct socket *sock) { + struct unix_conn *conn = sock->data; + _assert(conn); + return &conn->client_tx.wait; +} + + static int unix_socket_bind(struct socket *sock, struct sockaddr *sa, size_t len) { if (sa->sa_family != AF_UNIX) { return -EINVAL; @@ -258,7 +309,7 @@ static int unix_socket_accept(struct socket *serv, struct socket *client) { _assert(data_serv); // Wait for incoming connection attempts - _assert(!(data_serv->remote)); + //_assert(!(data_serv->remote)); while (!(conn = data_serv->remote)) { thread_wait_io(thread_self, &data_serv->client_notify); } diff --git a/sys/sys_file.c b/sys/sys_file.c index 47e7a47..e3c0850 100644 --- a/sys/sys_file.c +++ b/sys/sys_file.c @@ -441,7 +441,9 @@ static int sys_select_get_ready(struct ofile *fd) { static struct io_notify *sys_select_get_wait(struct ofile *fd) { if (fd->flags & OF_SOCKET) { - return &fd->socket.rx_notify; + struct io_notify *res = socket_get_rx_notify(&fd->socket); + _assert(res); + return res; } else { struct vnode *vn = fd->file.vnode; _assert(vn);