xref: /haiku/src/preferences/mail/DNSQuery.cpp (revision 97901ec593ec4dd50ac115c1c35a6d72f6e489a5)
1 #include "DNSQuery.h"
2 
3 #include <errno.h>
4 #include <stdio.h>
5 
6 #include <ByteOrder.h>
7 #include <FindDirectory.h>
8 #include <NetAddress.h>
9 #include <NetEndpoint.h>
10 #include <Path.h>
11 
// Debug output switch: while DEBUG is defined, PRINT() forwards all of its
// arguments to printf(); when it is not defined PRINT() expands to nothing.
 #define DEBUG 1

// Drop any PRINT definition that may already be in effect before installing
// the local one.
#undef PRINT
#ifdef DEBUG
#define PRINT(a...) printf(a)
#else
#define PRINT(a...)
#endif
20 
// Counter the DNS query IDs are drawn from; DNSQuery::_GetUniqueID()
// increments it with atomic_add() on every call.
static vint32 gID = 1;
22 
23 
// Creates a buffer without initial data; _Init() resets the read and the
// write position to 0.
BRawNetBuffer::BRawNetBuffer()
{
	_Init(NULL, 0);
}
28 
29 
// Creates a buffer without initial data and then sets the size of the
// internal storage to \a size. The read and the write position both start
// at 0.
BRawNetBuffer::BRawNetBuffer(off_t size)
{
	_Init(NULL, 0);
	fBuffer.SetSize(size);
}
35 
36 
// Creates a buffer and writes \a size bytes from \a buf into the internal
// storage. Note that _Init() leaves the write position at 0, i.e. it is not
// moved behind the initial data.
BRawNetBuffer::BRawNetBuffer(const void* buf, size_t size)
{
	_Init(buf, size);
}
41 
42 
43 status_t
44 BRawNetBuffer::AppendUint16(uint16 value)
45 {
46 	uint16 netVal = B_HOST_TO_BENDIAN_INT16(value);
47 	ssize_t sizeW = fBuffer.WriteAt(fWritePosition, &netVal, sizeof(uint16));
48 	if (sizeW == B_NO_MEMORY)
49 		return B_NO_MEMORY;
50 	fWritePosition += sizeof(uint16);
51 	return B_OK;
52 }
53 
54 
55 status_t
56 BRawNetBuffer::AppendString(const char* string)
57 {
58 	size_t length = strlen(string) + 1;
59 	ssize_t sizeW = fBuffer.WriteAt(fWritePosition, string, length);
60 	if (sizeW == B_NO_MEMORY)
61 		return B_NO_MEMORY;
62 	fWritePosition += length;
63 	return B_OK;
64 }
65 
66 
67 status_t
68 BRawNetBuffer::ReadUint16(uint16& value)
69 {
70 	uint16 netVal;
71 	ssize_t sizeW = fBuffer.ReadAt(fReadPosition, &netVal, sizeof(uint16));
72 	if (sizeW == 0)
73 		return B_ERROR;
74 	value= B_BENDIAN_TO_HOST_INT16(netVal);
75 	fReadPosition += sizeof(uint16);
76 	return B_OK;
77 }
78 
79 
80 status_t
81 BRawNetBuffer::ReadUint32(uint32& value)
82 {
83 	uint32 netVal;
84 	ssize_t sizeW = fBuffer.ReadAt(fReadPosition, &netVal, sizeof(uint32));
85 	if (sizeW == 0)
86 		return B_ERROR;
87 	value= B_BENDIAN_TO_HOST_INT32(netVal);
88 	fReadPosition += sizeof(uint32);
89 	return B_OK;
90 }
91 
92 
93 status_t
94 BRawNetBuffer::ReadString(BString& string)
95 {
96 	char* buffer = (char*)fBuffer.Buffer();
97 	buffer = &buffer[fReadPosition];
98 
99 	// if the string is compressed we have to follow the links to the
100 	// sub strings
101 	while (*buffer != 0) {
102 		if (uint8(*buffer) == 192) {
103 			// found a pointer mark
104 			buffer++;
105 			// pointer takes 2 byte
106 			fReadPosition = fReadPosition + 1;
107 			off_t pos = uint8(*buffer);
108 			_ReadSubString(string, pos);
109 			break;
110 		}
111 		string.Append(buffer, 1);
112 		buffer++;
113 		fReadPosition++;
114 	}
115 	fReadPosition++;
116 	return B_OK;
117 }
118 
119 
120 status_t
121 BRawNetBuffer::SkipReading(off_t skip)
122 {
123 	if (fReadPosition + skip > fBuffer.BufferLength())
124 		return B_ERROR;
125 	fReadPosition += skip;
126 	return B_OK;
127 }
128 
129 
130 void
131 BRawNetBuffer::_Init(const void* buf, size_t size)
132 {
133 	fWritePosition = 0;
134 	fReadPosition = 0;
135 	fBuffer.WriteAt(fWritePosition, buf, size);
136 }
137 
138 
139 void
140 BRawNetBuffer::_ReadSubString(BString& string, off_t pos)
141 {
142 	// sub strings have no links to other substrings so we can read it in one
143 	// piece
144 	char* buffer = (char*)fBuffer.Buffer();
145 	string.Append(&buffer[pos]);
146 }
147 
148 
149 // #pragma mark - DNSTools
150 
151 
152 status_t
153 DNSTools::GetDNSServers(BObjectList<BString>* serverList)
154 {
155 	// TODO: reading resolv.conf ourselves shouldn't be needed.
156 	// we should have some function to retrieve the dns list
157 #define	MATCH(line, name) \
158 	(!strncmp(line, name, sizeof(name) - 1) && \
159 	(line[sizeof(name) - 1] == ' ' || \
160 	 line[sizeof(name) - 1] == '\t'))
161 
162 	BPath path;
163 	if (find_directory(B_COMMON_SETTINGS_DIRECTORY, &path) != B_OK)
164 		return B_ENTRY_NOT_FOUND;
165 
166 	path.Append("network/resolv.conf");
167 
168 	register FILE* fp = fopen(path.Path(), "r");
169 	if (fp == NULL) {
170 		fprintf(stderr, "failed to open '%s' to read nameservers: %s\n",
171 			path.Path(), strerror(errno));
172 		return B_ENTRY_NOT_FOUND;
173 	}
174 
175 	int nserv = 0;
176 	char buf[1024];
177 	register char *cp; //, **pp;
178 //	register int n;
179 	int MAXNS = 2;
180 
181 	// read the config file
182 	while (fgets(buf, sizeof(buf), fp) != NULL) {
183 		// skip comments
184 		if (*buf == ';' || *buf == '#')
185 			continue;
186 
187 		// read nameservers to query
188 		if (MATCH(buf, "nameserver") && nserv < MAXNS) {
189 //			char sbuf[2];
190 			cp = buf + sizeof("nameserver") - 1;
191 			while (*cp == ' ' || *cp == '\t')
192 				cp++;
193 			cp[strcspn(cp, ";# \t\n")] = '\0';
194 			if ((*cp != '\0') && (*cp != '\n')) {
195 				serverList->AddItem(new BString(cp));
196 				nserv++;
197 			}
198 		}
199 		continue;
200 	}
201 
202 	fclose(fp);
203 
204 	return B_OK;
205 }
206 
207 
208 BString
209 DNSTools::ConvertToDNSName(const BString& string)
210 {
211 	BString outString = string;
212 	int32 dot, lastDot, diff;
213 
214 	dot = string.FindFirst(".");
215 	if (dot != B_ERROR) {
216 		outString.Prepend((char*)&dot, 1);
217 		// because we prepend a char add 1 more
218 		lastDot = dot + 1;
219 
220 		while (true) {
221 			dot = outString.FindFirst(".", lastDot + 1);
222 			if (dot == B_ERROR)
223 				break;
224 
225 			// set a counts to the dot
226 			diff =  dot - 1 - lastDot;
227 			outString[lastDot] = (char)diff;
228 			lastDot = dot;
229 		}
230 	} else
231 		lastDot = 0;
232 
233 	diff = outString.CountChars() - 1 - lastDot;
234 	outString[lastDot] = (char)diff;
235 
236 	return outString;
237 }
238 
239 
240 BString
241 DNSTools::ConvertFromDNSName(const BString& string)
242 {
243 	BString outString = string;
244 	int32 dot = string[0];
245 	int32 nextDot = dot;
246 	outString.Remove(0, sizeof(char));
247 	while (true) {
248 		dot = outString[nextDot];
249 		if (dot == 0)
250 			break;
251 		// set a "."
252 		outString[nextDot] = '.';
253 		nextDot+= dot + 1;
254 	}
255 	return outString;
256 }
257 
258 
259 // #pragma mark - DNSQuery
260 // see http://tools.ietf.org/html/rfc1035 for more information about DNS
261 
262 
// Nothing to set up here, the body is intentionally empty.
DNSQuery::DNSQuery()
{
}
266 
267 
// Nothing to release here, the body is intentionally empty.
DNSQuery::~DNSQuery()
{
}
271 
272 
273 status_t
274 DNSQuery::ReadDNSServer(in_addr* add)
275 {
276 	// list owns the items
277 	BObjectList<BString> dnsServerList(5, true);
278 	status_t status = DNSTools::GetDNSServers(&dnsServerList);
279 	if (status != B_OK)
280 		return status;
281 
282 	BString* firstDNS = dnsServerList.ItemAt(0);
283 	if (firstDNS == NULL || inet_aton(firstDNS->String(), add) != 1)
284 		return B_ERROR;
285 
286 	PRINT("dns server found: %s \n", firstDNS->String());
287 	return B_OK;
288 }
289 
290 
291 status_t
292 DNSQuery::GetMXRecords(BString serverName, BObjectList<mx_record>* mxList,
293 	bigtime_t timeout)
294 {
295 	// get the DNS server to ask for the mx record
296 	in_addr dnsAddress;
297 	if (ReadDNSServer(&dnsAddress) != B_OK)
298 		return B_ERROR;
299 
300 	// create dns query package
301 	BRawNetBuffer buffer;
302 	dns_header header;
303 	_SetMXHeader(&header);
304 	_AppendQueryHeader(buffer, &header);
305 
306 	BString serverNameConv = DNSTools::ConvertToDNSName(serverName);
307 	buffer.AppendString(serverNameConv.String());
308 	buffer.AppendUint16(uint16(MX_RECORD));
309 	buffer.AppendUint16(uint16(1));
310 
311 	// send the buffer
312 	PRINT("send buffer\n");
313 	BNetAddress netAddress(dnsAddress, 53);
314 	BNetEndpoint netEndpoint(SOCK_DGRAM);
315 	if (netEndpoint.InitCheck() != B_OK)
316 		return B_ERROR;
317 
318 	if (netEndpoint.Connect(netAddress) != B_OK)
319 		return B_ERROR;
320 	PRINT("Connected\n");
321 
322 #ifdef DEBUG
323 	int32 bytesSend =
324 #endif
325 	netEndpoint.Send(buffer.Data(), buffer.Size());
326 	PRINT("bytes send %i\n", int(bytesSend));
327 
328 	// receive buffer
329 	BRawNetBuffer receiBuffer(512);
330 	netEndpoint.SetTimeout(timeout);
331 #ifdef DEBUG
332 	int32 bytesRecei =
333 #endif
334 	netEndpoint.ReceiveFrom(receiBuffer.Data(), 512, netAddress);
335 	PRINT("bytes received %i\n", int(bytesRecei));
336 	dns_header receiHeader;
337 
338 	_ReadQueryHeader(receiBuffer, &receiHeader);
339 	PRINT("Package contains :");
340 	PRINT("%d Questions, ", receiHeader.q_count);
341 	PRINT("%d Answers, ", receiHeader.ans_count);
342 	PRINT("%d Authoritative Servers, ", receiHeader.auth_count);
343 	PRINT("%d Additional records\n", receiHeader.add_count);
344 
345 	// remove name and Question
346 	BString dummyS;
347 	uint16 dummy;
348 	receiBuffer.ReadString(dummyS);
349 	receiBuffer.ReadUint16(dummy);
350 	receiBuffer.ReadUint16(dummy);
351 
352 	bool mxRecordFound = false;
353 	for (int i = 0; i < receiHeader.ans_count; i++) {
354 		resource_record_head rrHead;
355 		_ReadResourceRecord(receiBuffer, &rrHead);
356 		if (rrHead.type == MX_RECORD) {
357 			mx_record *mxRec = new mx_record;
358 			_ReadMXRecord(receiBuffer, mxRec);
359 			PRINT("MX record found pri %i, name %s\n",
360 					mxRec->priority,
361 					mxRec->serverName.String());
362 			// Add mx record to the list
363 			mxList->AddItem(mxRec);
364 			mxRecordFound = true;
365 		} else {
366 			buffer.SkipReading(rrHead.dataLength);
367 		}
368 	}
369 
370 	if (!mxRecordFound)
371 		return B_ERROR;
372 
373 	return B_OK;
374 }
375 
376 
377 uint16
378 DNSQuery::_GetUniqueID()
379 {
380 	int32 nextId= atomic_add(&gID, 1);
381 	// just to be sure
382 	if (nextId > 65529)
383 		nextId = 0;
384 	return nextId;
385 }
386 
387 
388 void
389 DNSQuery::_SetMXHeader(dns_header* header)
390 {
391 	header->id = _GetUniqueID();
392 	header->qr = 0;      //This is a query
393 	header->opcode = 0;  //This is a standard query
394 	header->aa = 0;      //Not Authoritative
395 	header->tc = 0;      //This message is not truncated
396 	header->rd = 1;      //Recursion Desired
397 	header->ra = 0;      //Recursion not available! hey we dont have it (lol)
398 	header->z  = 0;
399 	header->rcode = 0;
400 	header->q_count = 1;   //we have only 1 question
401 	header->ans_count  = 0;
402 	header->auth_count = 0;
403 	header->add_count  = 0;
404 }
405 
406 
407 void
408 DNSQuery::_AppendQueryHeader(BRawNetBuffer& buffer, const dns_header* header)
409 {
410 	buffer.AppendUint16(header->id);
411 	uint16 data = 0;
412 	data |= header->rcode;
413 	data |= header->z << 4;
414 	data |= header->ra << 7;
415 	data |= header->rd << 8;
416 	data |= header->tc << 9;
417 	data |= header->aa << 10;
418 	data |= header->opcode << 11;
419 	data |= header->qr << 15;
420 	buffer.AppendUint16(data);
421 	buffer.AppendUint16(header->q_count);
422 	buffer.AppendUint16(header->ans_count);
423 	buffer.AppendUint16(header->auth_count);
424 	buffer.AppendUint16(header->add_count);
425 }
426 
427 
428 void
429 DNSQuery::_ReadQueryHeader(BRawNetBuffer& buffer, dns_header* header)
430 {
431 	buffer.ReadUint16(header->id);
432 	uint16 data = 0;
433 	buffer.ReadUint16(data);
434 	header->rcode = data & 0x0F;
435 	header->z = (data >> 4) & 0x07;
436 	header->ra = (data >> 7) & 0x01;
437 	header->rd = (data >> 8) & 0x01;
438 	header->tc = (data >> 9) & 0x01;
439 	header->aa = (data >> 10) & 0x01;
440 	header->opcode = (data >> 11) & 0x0F;
441 	header->qr = (data >> 15) & 0x01;
442 	buffer.ReadUint16(header->q_count);
443 	buffer.ReadUint16(header->ans_count);
444 	buffer.ReadUint16(header->auth_count);
445 	buffer.ReadUint16(header->add_count);
446 }
447 
448 
449 void
450 DNSQuery::_ReadMXRecord(BRawNetBuffer& buffer, mx_record* mxRecord)
451 {
452 	buffer.ReadUint16(mxRecord->priority);
453 	buffer.ReadString(mxRecord->serverName);
454 	mxRecord->serverName = DNSTools::ConvertFromDNSName(mxRecord->serverName);
455 }
456 
457 
458 void
459 DNSQuery::_ReadResourceRecord(BRawNetBuffer& buffer,
460 	resource_record_head *rrHead)
461 {
462 	buffer.ReadString(rrHead->name);
463 	buffer.ReadUint16(rrHead->type);
464 	buffer.ReadUint16(rrHead->dataClass);
465 	buffer.ReadUint32(rrHead->ttl);
466 	buffer.ReadUint16(rrHead->dataLength);
467 }
468