#!/usr/bin/ruby
#
# Details: Fgen subsystem, Tcpdump rule generator.
# Authors: Dmitry Maksyoma <dmaks@esphion.com>
# Started: 04/07/2006
#
# (c) 2006 Esphion Limited
# Copyright in this document, whether in written, electronic or other
# format including any source code or other computer code set forth
# therein or attached thereto and copyright in all parts thereof is owned
# by Esphion Limited, who reserves all rights therein.
# In particular, no part of this document/computer code may be reproduced,
# copied, stored in a retrieval system, used or transmitted by any means
# whatsoever without the prior written consent of Esphion Limited.
#
# All enquiries in relation thereto should be directed to:
#
# Esphion Ltd
# P.O. Box 300496,
# Albany,
# Auckland,
# New Zealand.
#
# Phone: +64 9 415 0227
# Fax:	 +64 9 415 0228
# Email: info@esphion.com

$:.insert(-2, File.join(File.dirname(File.expand_path($0)), "../lib"))
require 'pattern/entities'
require 'fgen/common'
include Entities

# An Array of byte values that additionally records `offset': the position,
# within its protocol layer, of the first byte it holds.
class FieldBytes < Array
  attr_accessor :offset
end

# Format one protocol header as a tcpdump filter fragment: a list of
# `proto[offset:size] = 0xVALUE' comparisons (the `:size' suffix is omitted
# for single bytes), optionally preceded by `vlan'/`mpls' keywords for an
# Ethernet header, all joined with ' and '.
#
# Since payload doesn't have its own identifier in the tcpdump
# proto[index] expression, payload is referenced through the protocol name
# of the layer it belongs to; `baseOffset' defines where data starts in
# that layer.
#
# header     - header object. Its class must respond to `fields', whose
#              entries are used as [name, offset, size] (the size
#              comparison below assumes size is the last element). The
#              object must respond to `layer' (indexed by byte offset;
#              presumably a Hash of offset => byte value) and to `[]'
#              (field value by Symbol name).
# proto      - tcpdump protocol keyword used as the index prefix, e.g.
#              'ether' or 'ip'. A nil proto yields a bare "[...]".
# baseOffset - Integer added to every byte offset, or a String tcpdump
#              arithmetic expression to which " + <offset>" is appended.
#              Any other value (including nil) leaves the index empty.
# override   - optional Hash of field start offset => ready-made expression;
#              when an entry is present and non-nil, that string is emitted
#              instead of the field's bytes.
#
# Returns the joined String ('' when nothing constrains the header).
#
# NOTE(review): for an EtherHeader, bytes past offset 13 (VLAN tags / MPLS
# labels) are deleted from header.layer IN PLACE; only their count survives,
# as `vlan'/`mpls' keywords. Callers re-using the header see the truncated
# layer.
def fmt_header(header, proto, baseOffset=0, override=nil)
  expr = []

  # Both branches evaluate to header.layer; the EtherHeader branch first
  # records VLAN/MPLS encapsulation and strips it from the layer.
  layer = unless header.is_a?(EtherHeader)
    header.layer
  else
    # If VLAN and/or MPLS headers are present in ethernet, all of them must be
    # included into resulting filter. For example, if MPLS headers contains
    # 2 stacks, `mpls and mpls' must be added to filter.
    mplsCount, rest, offset, vlanCount = 0, [], 13, 0
    # Collect the contiguous run of bytes following the 14-byte Ethernet
    # header (offsets 14, 15, ...) up to the first unset offset.
    while val = header.layer[offset += 1]
      rest << val
    end

    # The outer EtherType announcing MPLS accounts for the first label.
    if header.type_mpls
      mplsCount += 1
    elsif header.type_vlan
      # Walk 4-byte VLAN tags (TCI + inner EtherType). A tag is counted
      # before its size is checked, so a tag announced by the previous
      # EtherType is counted even when its bytes are absent. Once `rest' is
      # exhausted slice! returns [] (truthy); the size check ends the loop.
      while hdr = rest.slice!(0, 4)
	vlanCount += 1

	break unless hdr.size == 4
	etherType = Entities.merge_bytes(hdr[2, 2])
	next if etherType == EtherHeader::TYPE_VLAN

	mplsCount += 1 if etherType == EtherHeader::TYPE_MPLS_UNICAST
	break
      end
    end

    if mplsCount > 0
      # In each 4-byte MPLS label stack entry, bit 0 of byte 2 is the
      # bottom-of-stack bit; while it is clear another label follows.
      # A missing or short entry (stack[2] nil) stops the walk.
      while stack = rest.slice!(0, 4)
	break unless stack[2] && (stack[2] & 1) == 0

	mplsCount += 1
      end
    end

    vlanCount.times { expr.push 'vlan' }
    mplsCount.times { expr.push 'mpls' }

    # Encapsulation is now expressed by the keywords above; drop its bytes
    # so only offsets 0..13 are formatted. This mutates header.layer.
    header.layer.delete_if { |offset,| offset > 13 }
    header.layer
  end

  # Leave only the largest sized field for each offset.
  fields = {}
  header.class.fields.each { |fld|
    offset = fld[1]
    if (stored = fields[offset]).nil?
      fields[offset] = fld
    elsif stored.last < fld.last
      fields[offset] = fld 
    end
  }
  # Sort fields by offset.
  fields = fields.values.sort { |a, b| a[1] <=> b[1] }

  # Remove smaller overlapping fields.
  # `offset + size' marks the end of the last kept field. The overlap branch
  # shifts offset and size by the same amount, so that end never moves; any
  # field starting inside it is dropped, even if it extends past it.
  offset = size = nil
  fields = fields.select { |field|
    unless offset
      offset, size = field[1, 2]
      next field
    end

    if offset + size - 1 < field[1]
      offset, size = field[1, 2]
      field
    else
      offset += field[2]
      size -= field[2]
      nil
    end
  }

  # There still may be space left in `header' after the last field. In that
  # case, a field `unmarked', describing it, is added to `fields'. Payload
  # is a special case - unmarked space starts at 0 offset since no fields
  # exist in payload.
  # NOTE(review): if `layer' is empty, layerEnd is nil and the comparison
  # below raises (comparison of Integer with nil).
  lastField = fields.last
  spaceStart = lastField ? lastField[1] + lastField[2] : 0
  layerEnd = layer.keys.max
  if spaceStart <= layerEnd
    p = FieldBytes.new ['unmarked', spaceStart, layerEnd - spaceStart + 1]
    p.offset = spaceStart
    fields << p
  end

  fields.each { |field|
    name, offset, size = field
    #puts "name: #{name}, offset: #{offset}, size: #{size}"
    parts = []
    part = nil

    # `$_' is only a scratch variable holding the override expression.
    if override && $_ = override[offset]
      expr.push $_
      next
    end

    # Named fields come from the header accessor; unmarked space is read
    # straight from the layer. A non-Array value is treated as one byte.
    bytes = name != 'unmarked' ? header[name.intern] :
      Pattern.get_range(layer, offset, size)
    bytes = [bytes] unless bytes.is_a?(Array)

    # Split non-nil consecutive bytes to a number of 4-sized (at most)
    # arrays. nil bytes are not emitted and break the current run.
    bytes.each { |b|
      unless b
	part = nil
      else
	part = nil if part && part.size == 4

	unless part
	  parts.push(part = FieldBytes.new)
	  part.offset = offset
	end

	part << b
      end

      offset += 1
    }

    # Only parts sized 1, 2 and 4 are allowed. Thus, a part with 3 bytes is
    # split into 2 parts, the first with 2 bytes and the second with 1 byte.
    tmp = []
    parts.each { |part|
      unless part.size == 3
	tmp.push part
      else
	p = FieldBytes.new
	p.push part.pop
	p.offset = part.offset + 2
	tmp.push part, p
      end
    }
    parts = tmp

    # One comparison per part; tcpdump's default access size is one byte,
    # so the `:size' suffix is only written for 2- and 4-byte parts.
    parts.each { |part|
      offset = if baseOffset.is_a?(Integer)
	part.offset + baseOffset
      elsif baseOffset.is_a?(String)
	baseOffset + " + #{part.offset}"
      end
      expr.push "#{proto}[#{offset}" \
	"#{part.size > 1 ? ':' + part.size.to_s : ''}] = " \
	"#{'0x%x' % [Entities.merge_bytes(part)]}"
    }
  }

  expr.join ' and '
end

# Make a tcpdump expression matching any of the given dotted-quad IPv4
# addresses against the 4-byte IP header word at `offset' (12 is the
# source address, 16 the destination address).
#
# ips    - Array of dotted-quad Strings, e.g. ["10.0.0.1"]. It is NOT
#          modified: callers may reuse the same list for several patterns.
# offset - byte offset into the IP header.
#
# Returns a single "ip[off:4] == 0x..." term for one address, a
# parenthesised "or" of such terms for several, or nil for an empty list
# (an empty "()" group is not valid tcpdump syntax).
def fmt_ips(ips, offset)
  terms = ips.map { |ip|
    value = ip.split('.').inject(0) { |sum, b| (sum << 8) + b.to_i }
    "ip[#{offset}:4] == 0x#{'%x' % value}"
  }
  return nil if terms.empty?

  terms.size == 1 ? terms.first : "(#{terms.join(' or ')})"
end

# Build the complete tcpdump filter expression for `pattern', one fragment
# per layer (link, IP, IP payload protocol, payload), joined with ' and '.
#
# pattern - object exposing the link/ip/ipp/packetPayload layers.
# name    - not used by this method.
# src_ips - optional Array of dotted-quad source addresses; used only when
#           the pattern itself leaves the IP source address unconstrained.
# dst_ips - as src_ips, for the destination address.
#
# Returns the filter String. Layers that contribute no constraints are
# skipped rather than leaving empty terms in the output.
def format_pattern(pattern, name, src_ips, dst_ips)
  # Ignore link layer if command name is netx-pktcap.
  link = pattern.link unless File.basename($0) == 'netx-pktcap'
  ipl = pattern.ip
  ippl = pattern.ipp
  payload = pattern.packetPayload

  # tcpdump keyword for the protocol carried by IP (e.g. 'icmp', 'udp').
  # nil when there is no IP layer, and presumably also when the protocol
  # number has no description; the payload then belongs to the IP layer.
  proto = IPHeader.proto_desc[ipl.proto] if ipl

  expr = []
  expr << fmt_header(link, 'ether') if link
  if ipl
    # Caller-supplied address lists fill in the IP source (offset 12) and
    # destination (offset 16) fields only when the pattern leaves them
    # unconstrained.
    override = {}
    override[12] = fmt_ips(src_ips, 12) if src_ips && !ipl.srcip
    override[16] = fmt_ips(dst_ips, 16) if dst_ips && !ipl.dstip
    override = nil if override.empty?
    expr << fmt_header(ipl, 'ip', 0, override)
  end
  # NOTE(review): if `proto' is nil here, fmt_header emits indices with no
  # protocol prefix ("[...]"), which tcpdump will reject.
  expr << fmt_header(ippl, proto) if ippl

  # Add payload layer to output if possible.
  if payload && ipl
    # Use `proto' as an indication of presence of IP protocol layer in a
    # packet. If `proto' is nil, payload is for IP layer.
    if proto.nil?
      expr << fmt_header(payload, 'ip', ipl.dataOffset) if ipl.dataOffset
    elsif ipl.proto_icmp || ipl.proto_tcp || ipl.proto_udp
      if ipl.proto_tcp && (ippl.nil? || ippl.dataOffset.nil?)
        # TCP header length varies: let tcpdump compute the data start from
        # the high nibble of tcp[12] (32-bit words; >> 2 converts to bytes).
        expr << fmt_header(payload, proto, '((tcp[12] & 0xF0) >> 2)')
      else
        offset = if ippl
          ippl.dataOffset
        else
          case proto
          when 'icmp' then ICMPHeader.dataOffset
          when 'udp' then UDPHeader.dataOffset
          end
        end
        expr << fmt_header(payload, proto, offset) if offset
      end
    end
    # Payloads of other IP protocols are not expressed in the filter.
  end

  # A layer imposing no constraints yields ''. Joining it verbatim would
  # leave dangling conjunctions ("a and  and b"), which tcpdump rejects.
  expr.reject { |e| e.empty? }.join ' and '
end

# Script entry point. `do_format' is not defined in this file; presumably it
# comes from one of the required libraries (e.g. fgen/common) and drives
# format_pattern -- TODO confirm.
do_format
