xref: /openbsd-src/regress/sys/kern/unfdpass/unfdpass.c (revision 5a38ef86d0b61900239c7913d24a05e7b88a58f0)
1 /*	$OpenBSD: unfdpass.c,v 1.22 2021/12/13 16:56:50 deraadt Exp $	*/
2 /*	$NetBSD: unfdpass.c,v 1.3 1998/06/24 23:51:30 thorpej Exp $	*/
3 
4 /*-
5  * Copyright (c) 1998 The NetBSD Foundation, Inc.
6  * All rights reserved.
7  *
8  * This code is derived from software contributed to The NetBSD Foundation
9  * by Jason R. Thorpe of the Numerical Aerospace Simulation Facility,
10  * NASA Ames Research Center.
11  *
12  * Redistribution and use in source and binary forms, with or without
13  * modification, are permitted provided that the following conditions
14  * are met:
15  * 1. Redistributions of source code must retain the above copyright
16  *    notice, this list of conditions and the following disclaimer.
17  * 2. Redistributions in binary form must reproduce the above copyright
18  *    notice, this list of conditions and the following disclaimer in the
19  *    documentation and/or other materials provided with the distribution.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
22  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
23  * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
24  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
25  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
26  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
27  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
28  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31  * POSSIBILITY OF SUCH DAMAGE.
32  */
33 
34 /*
35  * Test passing of file descriptors over Unix domain sockets and socketpairs.
36  */
37 
38 #include <sys/socket.h>
39 #include <sys/time.h>
40 #include <sys/wait.h>
41 #include <sys/un.h>
42 #include <err.h>
43 #include <errno.h>
44 #include <fcntl.h>
45 #include <signal.h>
46 #include <stdio.h>
47 #include <stdlib.h>
48 #include <string.h>
49 #include <unistd.h>
50 
51 #define	SOCK_NAME	"test-sock"
52 
53 int	main(int, char *[]);
54 void	child(int, int, int);
55 void	catch_sigchld(int);
56 
57 /* ARGSUSED */
58 int
59 main(int argc, char *argv[])
60 {
61 	struct msghdr msg;
62 	int sock, pfd[2], fd, i;
63 	int listensock = -1;
64 	char fname[16], buf[64];
65 	struct cmsghdr *cmp;
66 	int *files = NULL;
67 	struct sockaddr_un sun, csun;
68 	int csunlen;
69 	pid_t pid;
70 	union {
71 		struct cmsghdr hdr;
72 		char buf[CMSG_SPACE(sizeof(int) * 3)];
73 	} cmsgbuf;
74 	int pflag, oflag, rflag;
75 	int type = SOCK_STREAM;
76 	extern char *__progname;
77 
78 	pflag = 0;
79 	oflag = 0;
80 	rflag = 0;
81 	while ((i = getopt(argc, argv, "opqr")) != -1) {
82 		switch (i) {
83 		case 'o':
84 			oflag = 1;
85 			break;
86 		case 'p':
87 			pflag = 1;
88 			break;
89 		case 'q':
90 			type = SOCK_SEQPACKET;
91 			break;
92 		case 'r':
93 			rflag = 1;
94 			break;
95 		default:
96 			fprintf(stderr, "usage: %s [-opqr]\n", __progname);
97 			exit(1);
98 		}
99 	}
100 
101 	/*
102 	 * Create the test files.
103 	 */
104 	for (i = 0; i < 5; i++) {
105 		(void) snprintf(fname, sizeof fname, "file%d", i + 1);
106 		if ((fd = open(fname, O_WRONLY|O_CREAT|O_TRUNC, 0666)) == -1)
107 			err(1, "open %s", fname);
108 		(void) snprintf(buf, sizeof buf, "This is file %d.\n", i + 1);
109 		if (write(fd, buf, strlen(buf)) != (ssize_t) strlen(buf))
110 			err(1, "write %s", fname);
111 		(void) close(fd);
112 	}
113 
114 	if (pflag) {
115 		/*
116 		 * Create the socketpair
117 		 */
118 		if (socketpair(PF_LOCAL, type, 0, pfd) == -1)
119 			err(1, "socketpair");
120 	} else {
121 		/*
122 		 * Create the listen socket.
123 		 */
124 		if ((listensock = socket(PF_LOCAL, type, 0)) == -1)
125 			err(1, "socket");
126 
127 		(void) unlink(SOCK_NAME);
128 		(void) memset(&sun, 0, sizeof(sun));
129 		sun.sun_family = AF_LOCAL;
130 		(void) strlcpy(sun.sun_path, SOCK_NAME, sizeof sun.sun_path);
131 
132 		if (bind(listensock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
133 			err(1, "bind");
134 
135 		if (listen(listensock, 1) == -1)
136 			err(1, "listen");
137 		pfd[0] = pfd[1] = -1;
138 	}
139 
140 	/*
141 	 * Create the sender.
142 	 */
143 	(void) signal(SIGCHLD, catch_sigchld);
144 	pid = fork();
145 	switch (pid) {
146 	case -1:
147 		err(1, "fork");
148 		/* NOTREACHED */
149 
150 	case 0:
151 		if (pfd[0] != -1)
152 			close(pfd[0]);
153 		child(pfd[1], type, oflag);
154 		/* NOTREACHED */
155 	}
156 
157 	if (pfd[0] != -1) {
158 		close(pfd[1]);
159 		sock = pfd[0];
160 	} else {
161 		/*
162 		 * Wait for the sender to connect.
163 		 */
164 		if ((sock = accept(listensock, (struct sockaddr *)&csun,
165 		    &csunlen)) == -1)
166 		err(1, "accept");
167 	}
168 
169 	/*
170 	 * Give sender a chance to run.  We will get going again
171 	 * once the SIGCHLD arrives.
172 	 */
173 	(void) sleep(10);
174 
175 	if (rflag) {
176 		if (read(sock, buf, sizeof(buf)) < 0)
177 			err(1, "read");
178 		printf("read successfully returned\n");
179 		exit(0);
180 	}
181 
182 	/*
183 	 * Grab the descriptors passed to us.
184 	 */
185 	memset(&msg, 0, sizeof(msg));
186 	msg.msg_control = &cmsgbuf.buf;
187 	msg.msg_controllen = sizeof(cmsgbuf.buf);
188 
189 	if (recvmsg(sock, &msg, 0) < 0) {
190 		if (errno == EMSGSIZE) {
191 			printf("recvmsg returned EMSGSIZE\n");
192 			exit(0);
193 		} else
194 			err(1, "recvmsg");
195 	}
196 
197 	(void) close(sock);
198 
199 	if (msg.msg_controllen == 0)
200 		errx(1, "no control messages received");
201 
202 	if (msg.msg_flags & MSG_CTRUNC)
203 		errx(1, "lost control message data");
204 
205 	for (cmp = CMSG_FIRSTHDR(&msg); cmp != NULL;
206 	    cmp = CMSG_NXTHDR(&msg, cmp)) {
207 		if (cmp->cmsg_level != SOL_SOCKET)
208 			errx(1, "bad control message level %d",
209 			    cmp->cmsg_level);
210 
211 		switch (cmp->cmsg_type) {
212 		case SCM_RIGHTS:
213 			if (cmp->cmsg_len != CMSG_LEN(sizeof(int) * 3))
214 				errx(1, "bad fd control message length %d",
215 				    cmp->cmsg_len);
216 
217 			files = (int *)CMSG_DATA(cmp);
218 			break;
219 
220 		default:
221 			errx(1, "unexpected control message");
222 			/* NOTREACHED */
223 		}
224 	}
225 
226 	/*
227 	 * Read the files and print their contents.
228 	 */
229 	if (files == NULL)
230 		warnx("didn't get fd control message");
231 	else {
232 		for (i = 0; i < 3; i++) {
233 			(void) memset(buf, 0, sizeof(buf));
234 			if (read(files[i], buf, sizeof(buf)) <= 0)
235 				err(1, "read file %d (%d)", i + 1, files[i]);
236 			printf("%s", buf);
237 		}
238 	}
239 
240 	/*
241 	 * All done!
242 	 */
243 	exit(0);
244 }
245 
/*
 * SIGCHLD handler: reap the exited sender with wait(2).  Delivery of the
 * signal is also what wakes the parent out of its sleep(3) in main().
 * errno is saved and restored because wait(2) may overwrite it and the
 * handler can interrupt the main program at any point.
 */
/* ARGSUSED */
void
catch_sigchld(int sig)
{
	int save_errno = errno;
	int status;

	(void)sig;
	(void) wait(&status);
	errno = save_errno;
}
256 
257 void
258 child(int sock, int type, int oflag)
259 {
260 	struct msghdr msg;
261 	char fname[16];
262 	struct cmsghdr *cmp;
263 	int i, fd, nfds = 3;
264 	struct sockaddr_un sun;
265 	size_t len;
266 	char *cmsgbuf;
267 	int *files;
268 
269 	/*
270 	 * Create socket if needed and connect to the receiver.
271 	 */
272 	if (sock == -1) {
273 		if ((sock = socket(PF_LOCAL, type, 0)) == -1)
274 			err(1, "child socket");
275 
276 		(void) memset(&sun, 0, sizeof(sun));
277 		sun.sun_family = AF_LOCAL;
278 		(void) strlcpy(sun.sun_path, SOCK_NAME, sizeof sun.sun_path);
279 
280 		if (connect(sock, (struct sockaddr *)&sun, sizeof(sun)) == -1)
281 			err(1, "child connect");
282 	}
283 
284 	if (oflag)
285 		nfds = 5;
286 	len = CMSG_SPACE(sizeof(int) * nfds);
287 	if ((cmsgbuf = malloc(len)) == NULL)
288 		err(1, "child");
289 
290 	(void) memset(&msg, 0, sizeof(msg));
291 	msg.msg_control = cmsgbuf;
292 	msg.msg_controllen = len;
293 
294 	cmp = CMSG_FIRSTHDR(&msg);
295 	cmp->cmsg_len = CMSG_LEN((sizeof(int) * nfds));
296 	cmp->cmsg_level = SOL_SOCKET;
297 	cmp->cmsg_type = SCM_RIGHTS;
298 
299 	/*
300 	 * Open the files again, and pass them to the parent over the socket.
301 	 */
302 	files = (int *)CMSG_DATA(cmp);
303 	for (i = 0; i < nfds; i++) {
304 		(void) snprintf(fname, sizeof fname, "file%d", i + 1);
305 		if ((fd = open(fname, O_RDONLY)) == -1)
306 			err(1, "child open %s", fname);
307 		files[i] = fd;
308 	}
309 
310 	if (sendmsg(sock, &msg, 0))
311 		err(1, "child sendmsg");
312 
313 	/*
314 	 * All done!
315 	 */
316 	exit(0);
317 }
318