tlstun

A simple Go program that adds TLS support to other listeners.
git clone https://git.parazyd.org/tlstun
Log | Files | Refs | README | LICENSE

tlstun.go (3666B)


      1 // Copyright (c) 2019-2021 Ivan J. <parazyd@dyne.org>
      2 //
      3 // This file is part of tlstun
      4 //
      5 // This program is free software: you can redistribute it and/or modify
      6 // it under the terms of the GNU Affero General Public License as published by
      7 // the Free Software Foundation, either version 3 of the License, or
      8 // (at your option) any later version.
      9 //
     10 // This program is distributed in the hope that it will be useful,
     11 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     13 // GNU Affero General Public License for more details.
     14 //
     15 // You should have received a copy of the GNU Affero General Public License
     16 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
     17 package main
     18 
     19 import (
     20 	"crypto/tls"
     21 	"crypto/x509"
     22 	"flag"
     23 	"fmt"
     24 	"io"
     25 	"io/ioutil"
     26 	"log"
     27 	"net"
     28 )
     29 
     30 var (
     31 	cacert  = flag.String("cacert", "ca.pem", "Path for CA certificate file")
     32 	cert    = flag.String("cert", "server.pem", "Path for Certificate file")
     33 	key     = flag.String("key", "server-key.pem", "Path for Key file")
     34 	listen  = flag.String("listen", "127.0.0.1:7443", "Listen address")
     35 	forward = flag.String("forward", "127.0.0.1:72", "Forward address")
     36 	fwtls   = flag.Bool("forwardtls", false, "Forward using TLS")
     37 	client  = flag.Bool("verifyclient", false, "Do client verification")
     38 	verbose = flag.Bool("verbose", false, "Verbose mode")
     39 	notls   = flag.Bool("notls", false, "Disable TLS and tunnel plain TCP")
     40 	tlsver  = flag.Int("tlsver", 13, "TLS version to use (11, 12, 13)")
     41 )
     42 
     43 func tlsConfig(cert, key string) (*tls.Config, error) {
     44 	creds, err := tls.LoadX509KeyPair(cert, key)
     45 	if err != nil {
     46 		return nil, err
     47 	}
     48 
     49 	tlscfg := &tls.Config{Certificates: []tls.Certificate{creds}}
     50 
     51 	if *client {
     52 		certpool, _ := x509.SystemCertPool()
     53 		if certpool == nil {
     54 			certpool = x509.NewCertPool()
     55 		}
     56 		pem, err := ioutil.ReadFile(*cacert)
     57 		if err != nil {
     58 			return nil, err
     59 		}
     60 		if !certpool.AppendCertsFromPEM(pem) {
     61 			return nil, fmt.Errorf("can't parse client certificate authority")
     62 		}
     63 		tlscfg.ClientCAs = certpool
     64 		tlscfg.ClientAuth = tls.RequireAndVerifyClientCert
     65 	}
     66 
     67 	switch *tlsver {
     68 	case 11:
     69 		tlscfg.MinVersion = tls.VersionTLS11
     70 	case 12:
     71 		tlscfg.MinVersion = tls.VersionTLS12
     72 	case 13:
     73 		tlscfg.MinVersion = tls.VersionTLS13
     74 	default:
     75 		log.Fatal("Unsupported TLS version:", *tlsver)
     76 	}
     77 
     78 	return tlscfg, nil
     79 }
     80 
     81 func tunnel(conn net.Conn, tlsCfg *tls.Config) {
     82 	var client net.Conn
     83 	var err error
     84 
     85 	if *fwtls {
     86 		client, err = tls.Dial("tcp", *forward, tlsCfg)
     87 	} else {
     88 		client, err = net.Dial("tcp", *forward)
     89 	}
     90 
     91 	if err != nil {
     92 		log.Fatal(err)
     93 	}
     94 
     95 	if *verbose {
     96 		log.Printf("Connected to localhost for %s\n", conn.RemoteAddr())
     97 	}
     98 
     99 	go func() {
    100 		defer client.Close()
    101 		defer conn.Close()
    102 		io.Copy(client, conn)
    103 	}()
    104 	go func() {
    105 		if *verbose {
    106 			defer log.Printf("Closed connection from %s\n", conn.RemoteAddr())
    107 		}
    108 		defer client.Close()
    109 		defer conn.Close()
    110 		io.Copy(conn, client)
    111 	}()
    112 }
    113 
    114 func server(tlsCfg *tls.Config) (net.Listener, error) {
    115 	t, err := net.Listen("tcp", *listen)
    116 	if err != nil {
    117 		return nil, err
    118 	}
    119 
    120 	if *notls {
    121 		return t, nil
    122 	}
    123 
    124 	return tls.NewListener(t, tlsCfg), nil
    125 }
    126 
    127 func main() {
    128 	flag.Parse()
    129 
    130 	var tlsCfg *tls.Config
    131 	var err error
    132 
    133 	if *notls {
    134 		tlsCfg = nil
    135 	} else {
    136 		tlsCfg, err = tlsConfig(*cert, *key)
    137 		if err != nil {
    138 			log.Fatal(err)
    139 		}
    140 	}
    141 
    142 	tcpsock, err := server(tlsCfg)
    143 	if err != nil {
    144 		log.Fatal(err)
    145 	}
    146 
    147 	for {
    148 		conn, err := tcpsock.Accept()
    149 		if err != nil {
    150 			log.Fatal(err)
    151 		}
    152 		if *verbose {
    153 			log.Printf("Accepted connection from %s\n", conn.RemoteAddr())
    154 		}
    155 		go tunnel(conn, tlsCfg)
    156 	}
    157 }