/*
    BFilter - a smart ad-filtering web proxy
    Copyright (C) 2002-2005  Joseph Artsimovich <joseph_a@mail.ru>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include "pch.h"

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "FlashInspector.h"
#include "DeflateDecompressor.h"
#include "StringUtils.h"
#include <cstring>
#include <cmath>
#include <limits>

using namespace std;

// Rectangle as decoded by readRect(): four signed edge coordinates.
// NOTE(review): the frame size is divided by 20 to obtain m_width/m_height,
// so the unit is presumably twips -- confirm against the SWF spec.
struct FlashInspector::Rect
{
	int32_t left;
	int32_t right;
	int32_t top;
	int32_t bottom;
	
	// Horizontal extent; negative if right < left.
	int32_t getWidth() const
	{
		return right - left;
	}
	
	// Vertical extent; negative if bottom < top.
	int32_t getHeight() const
	{
		return bottom - top;
	}
};


// Placement matrix as decoded by readMatrix().
// The scale terms are treated by the code in this file as 16.16 fixed-point
// numbers (they are multiplied by 1/65536 before use, and 0x10000 is the
// default when no scale is stored).
struct FlashInspector::Matrix
{
	// scale terms
	int32_t xScale;
	int32_t yScale;
	
	// rotate / skew terms (read, but not otherwise used in this file)
	int32_t xRotate;
	int32_t yRotate;
	
	// translation terms (read, but not otherwise used in this file)
	int32_t xTranslate;
	int32_t yTranslate;
};


// Fraction of the movie's frame area that a URL-carrying button placed on the
// main timeline must reach for the movie to be reported as clickable
// (compared against m_frameArea in processObjectPlacement()).
double const FlashInspector::BUTTON_AREA_THRESHOLD = 0.75;

// Moves pos forward to the next byte boundary.
// A position that is already on a byte boundary is left untouched.
inline void
FlashInspector::byteAlign(Position& pos)
{
	if (pos.bitpos <= 0) {
		return; // already aligned
	}
	pos.bitpos = 0;
	++pos.bytepos;
}

// Reads one byte at pos and advances pos past it.
// Throws OOB if pos has already reached end.
inline uint8_t
FlashInspector::readU8(DataIter& pos, DataIter const& end)
{
	if (pos != end) {
		uint8_t const byte = *pos;
		++pos;
		return byte;
	}
	throw OOB();
}

// Reads an nbits-wide two's complement bit field starting at pos and returns
// it sign-extended to 32 bits.
//
// nbits is expected to be in the range [0, 32]; the callers in this file
// obtain it from a 5-bit field, so it never exceeds 31.
// A zero-width field consumes no input and yields 0.  It is handled
// explicitly because the sign extension below would otherwise shift a 32-bit
// value by 32 bits, which is undefined behaviour.
//
// Throws OOB (via readUBits) if the field extends past end.
inline int32_t
FlashInspector::readSBits(Position& pos, DataIter const& end, int nbits)
{
	if (nbits <= 0) {
		return 0;
	}
	// Move the field's top bit into bit 31, then shift back arithmetically
	// so that the sign bit is replicated.
	int const shift = std::numeric_limits<uint32_t>::digits - nbits;
	uint32_t const val = readUBits(pos, end, nbits) << shift;
	return int32_t(val) >> shift;
}

// Creates an inspector that is ready to receive the first chunk of a movie.
FlashInspector::FlashInspector()
:	m_status(IN_PROGRESS),
	m_state(ST_HEADER),
	m_flashVersion(0),
	m_width(0),
	m_height(0),
	m_curFrame(0),
	m_isClickable(false)
{
	// These members are only given real values once the corresponding part
	// of the stream has been parsed.  Start them from a defined state so
	// that nothing indeterminate can be observed if parsing fails early.
	// They are assigned here rather than in the initializer list to stay
	// independent of the declaration order in the class.
	m_tagCode = 0;
	m_tagLength = 0;
	m_frameRate = 0;
	m_frameArea = 0.0;
}

// Nothing is released explicitly here; the body is intentionally empty.
FlashInspector::~FlashInspector()
{
}

void
FlashInspector::reset()
{
	m_status = IN_PROGRESS;
	m_state = ST_HEADER;
	m_incomingData.clear();
	m_data.clear();
	m_pDecomp.reset(0);
	m_shapes.clear();
	m_flashVersion = 0;
	m_width = 0;
	m_height = 0;
	m_curFrame = 0;
	m_isClickable = false;
}

// Feeds the next chunk of the movie to the inspector.
//
// data - the new chunk; it is emptied by this call.
// eof  - true if no more data will follow.
//
// The raw chunk is always recorded in m_incomingData.  While inspection is
// still in progress, the chunk is additionally passed (through the
// decompressor, if one is installed) to the parser.  Inspection ends with
// COMPLETE or FAILED when the parser decides so, when the input ends, or
// when more than INSPECTION_SIZE_LIMIT bytes have been received.
void
FlashInspector::consumeDataChunk(SplittableBuffer& data, bool eof)
{
	if (m_status != IN_PROGRESS) {
		// The verdict is already known; just keep collecting the data.
		m_incomingData.appendDestructive(data);
		return;
	}
	
	m_incomingData.append(data);
	
	if (m_pDecomp.get()) {
		// compressed movie: parse the decompressor's output
		m_pDecomp->consume(data, eof);
		m_pDecomp->retrieveOutput(m_data);
		if (m_pDecomp->isError()) {
			m_status = FAILED;
			return;
		}
	} else {
		// uncompressed movie: parse the data as is
		m_data.appendDestructive(data);
	}
	
	processNewData(eof);
	
	if (m_status == IN_PROGRESS
	    && (eof || m_incomingData.size() > INSPECTION_SIZE_LIMIT)) {
		// Out of input (or out of patience).  If we got past both headers
		// the movie is accepted as inspected, otherwise it is not
		// considered a usable flash file.
		if (m_state > ST_HEADER2) {
			m_status = COMPLETE;
		} else {
			m_status = FAILED;
		}
	}
}

// Advances the parsing state machine as far as the data buffered in m_data
// allows.
//
// The states are visited in the order ST_HEADER -> ST_HEADER2 ->
// ST_TAG_HEADER <-> ST_TAG_BODY.  Each state trims the bytes it has consumed
// from the front of m_data only after it has read everything it needs, so a
// state that runs out of data can simply be re-entered on the next call.
//
// Running out of data is signalled by an OOB exception, which is caught by
// the function-level handler at the bottom and silently ends this call.
// An OOB exception raised while processing a complete tag body is caught
// separately and marks the movie as FAILED.
//
// eof - true if no more data will follow; only forwarded to the decompressor
//       when a compressed movie is detected.
void
FlashInspector::processNewData(bool eof)
try {
	switch (m_state) {
		case ST_TAG_BODY:
		case_ST_TAG_BODY: {
			// Wait until the whole tag body has been buffered.
			if (m_data.size() < m_tagLength) {
				return;
			}
			Position pos(m_data.begin());
			DataIter tag_end(pos.bytepos);
			tag_end += m_tagLength;
			try {
				// Dispatch on the SWF tag code.
				switch(m_tagCode) {
					case 7: {
						processButton(pos, tag_end);
						break;
					}
					case 34: {
						processButton2(pos, tag_end);
						break;
					}
					case 4: {
						processPlaceObject(pos, tag_end);
						break;
					}
					case 26: {
						processPlaceObject2(pos, tag_end);
						break;
					}
					case 39: {
						processSprite(pos, tag_end);
						break;
					}
					// Shape definitions: only the character id and the
					// bounding rectangle at the start of the tag are used.
					case 2:
					case 22:
					case 32:
					case 46: {
						uint16_t char_id = readU16(pos.bytepos, tag_end);
						Rect bounds;
						readRect(bounds, pos, tag_end);
						m_shapes[char_id] = Shape(bounds.getWidth(), bounds.getHeight());
						break;
					}
					// One more frame has been seen; stop inspecting once
					// INSPECTION_FRAME_LIMIT frames have been processed.
					case 1: {
						if (++m_curFrame >= INSPECTION_FRAME_LIMIT) {
							m_status = COMPLETE;
							return;
						}
						break;
					}
					default: break;
				}
			} catch (OOB& e) { // broken file
				// The whole tag body was available, so running out of
				// bounds here means the tag itself is malformed.
				m_status = FAILED;
				return;
			}
			// A tag handler may have reached a verdict.
			if (m_status != IN_PROGRESS) {
				return;
			}
			m_data.trimFront(tag_end);
			m_state = ST_TAG_HEADER;
			// fall through
		}
		case ST_TAG_HEADER:
		case_ST_TAG_HEADER: {
			DataIter pos(m_data.begin());
			DataIter const end(m_data.end());
			// The upper 10 bits hold the tag code, the lower 6 bits the
			// length; a length of 0x3F means that the real length follows
			// as a 32-bit value.
			uint16_t code_and_length = readU16(pos, end);
			m_tagCode = code_and_length >> 6;
			m_tagLength = code_and_length & 0x3F;
			if (m_tagLength == 0x3F) {
				m_tagLength = readU32(pos, end);
			}
			m_data.trimFront(pos);
			// Tag code 0 with length 0 ends the inspection.
			if (code_and_length == 0) {
				m_status = COMPLETE;
				return;
			}
			m_state = ST_TAG_BODY;
			goto case_ST_TAG_BODY;
		}
		case ST_HEADER: {
			DataIter pos(m_data.begin());
			DataIter const end(m_data.end());
			// 8 bytes: a 3 byte signature, the version byte, and 4 more
			// bytes that are read but not used here.
			uint8_t header[8];
			readBytes(pos, end, sizeof(header), header);
			if (!memcmp(header, "FWS", 3)) {
				m_data.trimFront(pos);
			} else if (!memcmp(header, "CWS", 3)) {
				// Compressed movie: from now on, everything goes through
				// the decompressor, starting with what is already buffered.
				m_data.trimFront(pos);
				m_pDecomp.reset(new DeflateDecompressor(DeflateDecompressor::ZLIB));
				m_pDecomp->consume(m_data, eof);
				m_pDecomp->retrieveOutput(m_data);
				if (m_pDecomp->isError()) {
					m_status = FAILED;
					return;
				}
			} else {
				// not a flash file
				m_status = FAILED;
				return;
			}
			m_flashVersion = header[3];
			m_state = ST_HEADER2;
			// fall through
		}
		case ST_HEADER2: {
			Position pos(m_data.begin());
			DataIter const end(m_data.end());
			Rect frame_rect;
			readRect(frame_rect, pos, end);
			byteAlign(pos);
			if (frame_rect.getWidth() < 0 || frame_rect.getHeight() < 0) {
				// broken file?
				m_status = FAILED;
				return;
			}
			// The area is kept in the rectangle's own units, the same
			// units in which shape sizes are stored.
			m_frameArea = double(frame_rect.getWidth())*double(frame_rect.getHeight());
			m_width = frame_rect.getWidth() / 20;
			m_height = frame_rect.getHeight() / 20;
			m_frameRate = readU16(pos.bytepos, end);
			skipBytes(pos.bytepos, end, 2); // skip the frame count
			m_data.trimFront(pos.bytepos);
			m_state = ST_TAG_HEADER;
			goto case_ST_TAG_HEADER;
		}
	}
} catch (OOB& e) {
	// more data is needed
}

// Handles a button definition (tag 7): a character id, a list of button
// records and then a list of action records.
//
// The button is always registered in m_shapes with the size computed by
// processButtonRecords().  Its button size (btn_width / btn_height) is only
// filled in if one of its actions turns out to be a URL action.
//
// Throws OOB if the tag is truncated or malformed.
void
FlashInspector::processButton(Position& pos, DataIter const& end)
{
	uint16_t const button_id = readU16(pos.bytepos, end);
	Shape& button = m_shapes[button_id]; // inserts a record if necessary
	processButtonRecords(button, pos, end, false);
	
	for (;;) {
		uint8_t const opcode = readU8(pos.bytepos, end);
		if (opcode == 0) {
			return; // end of the action list
		}
		if ((opcode & 0x80) == 0) {
			// an action without a payload; can't be a URL action
			continue;
		}
		
		size_t const payload_length = readU16(pos.bytepos, end);
		if (end - pos.bytepos < payload_length) {
			throw OOB();
		}
		DataIter payload_end(pos.bytepos);
		payload_end += payload_length;
		
		bool const may_be_url_action = (opcode == 0x83 || opcode == 0x9a);
		if (may_be_url_action && processUrlAction(pos, payload_end, opcode)) {
			button.btn_width = button.width;
			button.btn_height = button.height;
			return;
		}
		
		pos.bytepos = payload_end;
	}
}

// Handles an extended button definition (tag 34): a character id, a flags
// byte, a 16-bit offset to the button's actions, the button records, and
// finally the actions themselves.
//
// The button is registered in m_shapes (with both its size and its button
// size filled in) only if one of its actions is a URL action.
// Menu buttons (lowest flag bit set) are ignored.
//
// The action offset is relative to the start of the offset field itself.
// An offset of zero means that the button has no actions at all; such a
// button is not interesting to us.  Without this check the button records
// would be misinterpreted as actions.  A non-zero offset smaller than the
// offset field itself (2 bytes) would place the end of the button records
// before their beginning, so it is rejected as well.
//
// Throws OOB if the tag is truncated or malformed.
void
FlashInspector::processButton2(Position& pos, DataIter const& end)
{
	uint16_t char_id = readU16(pos.bytepos, end);
	if (readU8(pos.bytepos, end) & 1) { // menu button
		return;
	}
	
	Position button_actions(pos); // points to the offset field for now
	size_t button_action_offset = readU16(pos.bytepos, end);
	if (button_action_offset == 0) {
		// no actions, so no URL actions either
		return;
	}
	if (button_action_offset < 2 || end - button_actions.bytepos < button_action_offset) {
		throw OOB();
	}
	button_actions.bytepos += button_action_offset;
	
	// The button records occupy the space between the offset field
	// and the first action.
	DataIter characters_end(button_actions.bytepos);
	if (processButtonActions(button_actions, end)) {
		Shape& hit_area = m_shapes[char_id]; // will insert a record to m_shapes
		processButtonRecords(hit_area, pos, characters_end, true);
		hit_area.btn_width = hit_area.width;
		hit_area.btn_height = hit_area.height;
	}
}

// Walks the list of conditional action blocks of an extended button and
// returns true as soon as one of them contains a URL action, or false if
// none does.
//
// Each block starts with a 16-bit size (the distance from the start of the
// block to the start of the next one, or zero for the last block), followed
// by 2 bytes of conditions and the action records.
//
// A non-zero size smaller than the 4 byte block header is rejected, as it
// would place the end of the actions before their beginning and let the
// action parser run past its bounds.
//
// Throws OOB if the data is truncated or malformed.
bool
FlashInspector::processButtonActions(Position& pos, DataIter const& end)
{
	while (true) {
		DataIter action_end(pos.bytepos);
		size_t next_action_offset = readU16(pos.bytepos, end);
		if (next_action_offset == 0) {
			// the last block extends to the end of the tag
			action_end = end;
		} else {
			if (next_action_offset < 4 || end - action_end < next_action_offset) {
				throw OOB();
			}
			action_end += next_action_offset;
		}
		skipBytes(pos.bytepos, end, 2); // skip the condition
		if (processActionRecords(pos, action_end)) {
			return true;
		}
		pos.bytepos = action_end;
		if (next_action_offset == 0) {
			break;
		}
	}
	return false;
}

// Scans a list of action records and returns true if it contains a URL
// action (as determined by processUrlAction()), or false if the terminating
// zero action code is reached first.
//
// Action codes with the high bit set are followed by a 16-bit payload
// length and the payload; all other action codes consist of the code only.
//
// Throws OOB if the list is truncated or a payload extends past end.
bool
FlashInspector::processActionRecords(Position& pos, DataIter const& end)
{
	while (true) {
		uint8_t action_code = readU8(pos.bytepos, end);
		if (action_code == 0) {
			break;
		}
		if (action_code & 0x80) {
			size_t action_length = readU16(pos.bytepos, end);
			if (end - pos.bytepos < action_length) {
				throw OOB();
			}
			DataIter action_end(pos.bytepos);
			action_end += action_length;
			if (action_code == 0x83 || action_code == 0x9a) {
				if (processUrlAction(pos, action_end, action_code)) {
					return true;
				}
			}
			// continue with the next action
			pos.bytepos = action_end;
			pos.bitpos = 0;
		}
	}
	return false;
}

// Tells whether an action should be counted as a URL action.
//
// pos  - the beginning of the action's payload (not modified).
// end  - the end of the action's payload.
// code - the action code.
//
// Action 0x9a always counts.  The payload of action 0x83 consists of two
// zero-terminated strings; it counts unless the second string starts with
// "_level".  Any other action code doesn't count.
//
// Throws OOB if the first string of a 0x83 action isn't terminated.
bool
FlashInspector::processUrlAction(Position& pos, DataIter const& end, uint8_t code)
{
	switch (code) {
		case 0x9a: {
			return true;
		}
		case 0x83: {
			static char const level[] = {'_','l','e','v','e','l'};
			
			// find the terminator of the first string ...
			DataIter target_begin(SplittableBuffer::find(pos.bytepos, end, '\0'));
			if (target_begin == end) {
				throw OOB();
			}
			// ... the second string starts right after it
			++target_begin;
			
			bool const targets_level = StringUtils::startsWith(
				target_begin, end, level, level+sizeof(level)
			);
			return !targets_level;
		}
		default: {
			return false;
		}
	}
}

// Fills the hit_area structure (width and height only) with the scaled size
// of the largest known shape used for the button's Up or Hit state.
//
// hit_area   - receives the size; left untouched if no suitable shape is found.
// pos        - the beginning of the button records; advanced past the records
//              that were processed.
// end        - the end of the button records.
// has_cxform - true if every record carries a color transform after its matrix.
//
// Returns true if a shape with a positive area was found.
//
// Only strictly positive areas are taken into account.  The original code
// expressed that by starting the maximum at numeric_limits<double>::min(),
// which is the smallest positive double rather than the lowest one; the
// threshold is now spelled out as 0.0.  Sizes are products of integers and
// multiples of 1/65536, so no area can fall between the two values.
//
// Throws OOB if the records are truncated or malformed.
bool
FlashInspector::processButtonRecords(Shape& hit_area, Position& pos, DataIter const& end, bool has_cxform)
{
	double max_area = 0.0;
	while (true) {
		uint8_t flags = readU8(pos.bytepos, end);
		if (!flags) {
			break; // end of the records
		}
		
		uint16_t char_id = readU16(pos.bytepos, end);
		skipBytes(pos.bytepos, end, 2); // skip depth
		Matrix mat;
		readMatrix(mat, pos, end);
		if (has_cxform) {
			skipColorTransformAlpha(pos, end);
		}
		byteAlign(pos);
		
		if (flags & (1|8)) { // Up or Hit state
			std::map<uint16_t, Shape>::iterator fit = m_shapes.find(char_id);
			if (fit == m_shapes.end()) {
				// refers to a character we know nothing about
				continue;
			}
			// the scale terms are 16.16 fixed-point numbers
			double width = fabs((1.0/65536)*mat.xScale) * fit->second.width;
			double height = fabs((1.0/65536)*mat.yScale) * fit->second.height;
			double area = width*height;
			if (area > max_area) {
				max_area = area;
				hit_area.width = static_cast<int32_t>(width);
				hit_area.height = static_cast<int32_t>(height);
			}
		}
	}
	return max_area > 0.0;
}

// Handles an object placement (tag 4): a character id, a depth and a matrix.
//
// If the character is known, the placement is evaluated by
// processObjectPlacement().  sprite is the sprite being defined, or null if
// the object is being placed on the main timeline.
//
// Throws OOB if the tag is truncated.
void
FlashInspector::processPlaceObject(Position& pos, DataIter const& end, Shape* sprite)
{
	uint16_t const placed_id = readU16(pos.bytepos, end);
	skipBytes(pos.bytepos, end, 2); // skip the depth
	
	Matrix placement;
	readMatrix(placement, pos, end);
	byteAlign(pos);
	
	std::map<uint16_t, Shape>::iterator const known = m_shapes.find(placed_id);
	if (known == m_shapes.end()) {
		return; // an unknown character; nothing to evaluate
	}
	processObjectPlacement(known->second, &placement, sprite);
}

// Handles an extended object placement (tag 26).
//
// The tag starts with a flags byte that tells which of the optional fields
// are present.  Placements that don't name a character, or that name an
// unknown character, are ignored.  If the placement carries clip actions
// that react to the mouse and contain a URL action, the placed object itself
// is marked as a button covering its whole area.  Finally the placement is
// evaluated by processObjectPlacement().
//
// sprite is the sprite being defined, or null if the object is being placed
// on the main timeline.
//
// Throws OOB if the tag is truncated or malformed.
void
FlashInspector::processPlaceObject2(Position& pos, DataIter const& end, Shape* sprite)
{
	enum {
		FLAG_MOVE                = 0x01,
		FLAG_HAS_CHARACTER       = 0x02,
		FLAG_HAS_MATRIX          = 0x04,
		FLAG_HAS_COLOR_TRANSFORM = 0x08,
		FLAG_HAS_RATIO           = 0x10,
		FLAG_HAS_NAME            = 0x20,
		FLAG_HAS_CLIP_DEPTH      = 0x40,
		FLAG_HAS_CLIP_ACTIONS    = 0x80
	};
	
	uint8_t const flags = readU8(pos.bytepos, end);
	if ((flags & FLAG_HAS_CHARACTER) == 0) {
		return;
	}
	
	skipBytes(pos.bytepos, end, 2); // skip depth
	uint16_t const placed_id = readU16(pos.bytepos, end);
	std::map<uint16_t, Shape>::iterator const known = m_shapes.find(placed_id);
	if (known == m_shapes.end()) {
		return;
	}
	Shape& object = known->second;
	
	Matrix matrix;
	bool const has_matrix = (flags & FLAG_HAS_MATRIX) != 0;
	if (has_matrix) {
		readMatrix(matrix, pos, end);
	}
	
	if ((flags & FLAG_HAS_CLIP_ACTIONS) != 0) {
		// The clip actions come last, so every optional field
		// that precedes them has to be skipped.
		if ((flags & FLAG_HAS_COLOR_TRANSFORM) != 0) {
			skipColorTransformAlpha(pos, end);
		}
		byteAlign(pos);
		if ((flags & FLAG_HAS_RATIO) != 0) {
			skipBytes(pos.bytepos, end, 2);
		}
		if ((flags & FLAG_HAS_NAME) != 0) {
			skipString(pos.bytepos, end);
		}
		if ((flags & FLAG_HAS_CLIP_DEPTH) != 0) {
			skipBytes(pos.bytepos, end, 2);
		}
		
		bool const acts_as_button = processClipActions(pos, end);
		if (acts_as_button) {
			object.btn_width = object.width;
			object.btn_height = object.height;
		}
	}
	
	processObjectPlacement(object, has_matrix ? &matrix : 0, sprite);
}

// Scans the clip actions attached to an object placement and returns true if
// the clip reacts to mouse events and one of its handlers contains a URL
// action.
//
// The structure starts with 2 reserved bytes and the union of the event
// flags of all handlers, followed by the handlers themselves.  Each handler
// consists of its event flags, a 32-bit size, an optional key code (present
// if bit 9 of the event flags is set) and the action records.  A handler
// with no event flags set terminates the list.
//
// The handler size counts the bytes from the end of the size field to the
// next handler, so it includes the optional key code.  The end of the
// handler is therefore established before the key code is skipped.
//
// The mouse related events are bits 27-29 (mouse move / down / up) and bits
// 18-19 (press / release) of the flags returned by readClipEventFlags().
//
// Throws OOB if the data is truncated or malformed.
bool
FlashInspector::processClipActions(Position& pos, DataIter const& end)
{
	uint32_t const mouse_flags = (1UL << 18) | (1UL << 19)
		| (1UL << 27) | (1UL << 28) | (1UL << 29);
	skipBytes(pos.bytepos, end, 2); // skip reserved bytes
	uint32_t const all_events = readClipEventFlags(pos, end);
	if (!(all_events & mouse_flags)) {
		// nothing interesting for us here
		return false;
	}
	
	bool url_action_found = false;

	while (!url_action_found) {
		uint32_t const events = readClipEventFlags(pos, end);
		if (events == 0) {
			break;
		}
		size_t const record_length = readU32(pos.bytepos, end);
		if (end - pos.bytepos < record_length) {
			throw OOB();
		}
		DataIter actions_end(pos.bytepos);
		actions_end += record_length;
		
		if (events & (1 << 9)) {
			// skip key code, which is a part of the handler
			skipBytes(pos.bytepos, actions_end, 1);
		}
		
		url_action_found = processActionRecords(pos, actions_end);
		pos.bytepos = actions_end;
		pos.bitpos = 0;
	}
	
	return url_action_found;
}

// Reads a set of clip event flags and returns it in the 32-bit layout, where
// the first flag of the stream ends up in bit 31.
//
// Movies older than version 6 store only the first 16 flags.  Those are
// moved to the upper half of the result, so that every flag has the same bit
// position regardless of the movie's version.  Without this, the flags of an
// old movie would never match the bit masks used by processClipActions(),
// all of which refer to the 32-bit layout.
//
// Throws OOB if the flags extend past end.
uint32_t
FlashInspector::readClipEventFlags(Position& pos, DataIter const& end)
{
	if (m_flashVersion < 6) {
		return readUBits(pos, end, 16) << 16;
	}
	return readUBits(pos, end, 32);
}

// Evaluates the placement of an object.
//
// object - the object being placed.
// mat    - the placement matrix, or null for an unscaled placement.
// sprite - the sprite the object is being placed into, or null if it's
//          being placed on the main timeline.
//
// On the main timeline, an object whose scaled button area reaches
// BUTTON_AREA_THRESHOLD of the frame area makes the movie clickable and
// completes the inspection.  Inside a sprite, the sprite's size and button
// size grow to those of the largest object (respectively button) placed
// into it.
void
FlashInspector::processObjectPlacement(
	Shape const& object, Matrix const* mat, Shape* sprite)
{
	// the scale terms are 16.16 fixed-point numbers
	double xscale = 1.0;
	double yscale = 1.0;
	if (mat) {
		xscale = fabs((1.0/65536)*mat->xScale);
		yscale = fabs((1.0/65536)*mat->yScale);
	}
	
	double const scaled_btn_width = object.btn_width * xscale;
	double const scaled_btn_height = object.btn_height * yscale;
	double const scaled_btn_area = scaled_btn_width * scaled_btn_height;
	
	if (sprite) {
		double const scaled_width = object.width * xscale;
		double const scaled_height = object.height * yscale;
		if (scaled_width * scaled_height > sprite->getArea()) {
			sprite->width = int32_t(scaled_width);
			sprite->height = int32_t(scaled_height);
		}
		if (scaled_btn_area > sprite->getButtonArea()) {
			sprite->btn_width = int32_t(scaled_btn_width);
			sprite->btn_height = int32_t(scaled_btn_height);
		}
		return;
	}
	
	if (scaled_btn_area >= BUTTON_AREA_THRESHOLD * m_frameArea) {
		m_status = COMPLETE;
		m_isClickable = true;
	}
}

// Handles a sprite definition (tag 39): a character id, a frame count and a
// sequence of nested tags.
//
// The sprite is registered in m_shapes, and the object placements among its
// nested tags (tags 4 and 26) are processed with the sprite as their target.
// All other nested tags are skipped.  A nested tag with both code and length
// equal to zero ends the sequence.
//
// Throws OOB if the tag is truncated or if a nested tag claims to extend
// past the end of the sprite.  The latter has to be checked explicitly:
// otherwise the position would be moved past end, and the "position == end"
// tests that guard all subsequent reads could never succeed.
void
FlashInspector::processSprite(Position& pos, DataIter const& end)
{
	uint16_t char_id = readU16(pos.bytepos, end);
	skipBytes(pos.bytepos, end, 2); // skip frame count
	Shape* sprite = &m_shapes[char_id]; // create a record
	
	while (pos.bytepos != end) {
		uint16_t code_and_length = readU16(pos.bytepos, end);
		if (code_and_length == 0) {
			break;
		}
		// same encoding as for the top-level tags
		uint16_t tag_code = code_and_length >> 6;
		size_t tag_length = code_and_length & 0x3F;
		if (tag_length == 0x3F) {
			tag_length = readU32(pos.bytepos, end);
		}
		if (end - pos.bytepos < tag_length) {
			throw OOB();
		}
		DataIter tag_end(pos.bytepos);
		tag_end += tag_length;
		if (tag_code == 4) {
			processPlaceObject(pos, tag_end, sprite);
		} else if (tag_code == 26) {
			processPlaceObject2(pos, tag_end, sprite);
		}
		// continue with the next nested tag
		pos.bytepos = tag_end;
		pos.bitpos = 0;
	}
}

// Reads a rectangle: a 5-bit field width followed by four signed fields of
// that width, stored in the order left, right, top, bottom.
// The position is not byte-aligned afterwards.
//
// Throws OOB if the rectangle extends past end.
void
FlashInspector::readRect(Rect& rect, Position& pos, DataIter const& end)
{
	int const field_width = readUBits(pos, end, 5);
	int32_t* const fields[] = {
		&rect.left, &rect.right, &rect.top, &rect.bottom
	};
	for (size_t i = 0; i < sizeof(fields)/sizeof(fields[0]); ++i) {
		*fields[i] = readSBits(pos, end, field_width);
	}
}

// Reads a matrix, starting at the next byte boundary.
//
// The scale and the rotate / skew terms are optional, each pair being
// preceded by a presence bit and a 5-bit field width.  The translation
// terms are always present and are preceded by their 5-bit field width.
//
// Terms that are not stored get the values of the identity transformation:
// 0x10000 (that's 1.0 in 16.16 fixed-point) for the scale terms and zero for
// the rotate / skew terms.  The original code used 0x10000 for the latter
// as well, which doesn't describe an identity transformation.  Nothing in
// this file reads the rotate / skew terms, so this has no further effect.
//
// The position is not byte-aligned afterwards.
// Throws OOB if the matrix extends past end.
void
FlashInspector::readMatrix(Matrix& mat, Position& pos, DataIter const& end)
{
	byteAlign(pos);
	if (readUBits(pos, end, 1)) {
		int nScaleBits = readUBits(pos, end, 5);
		mat.xScale = readSBits(pos, end, nScaleBits);
		mat.yScale = readSBits(pos, end, nScaleBits);
	} else {
		mat.xScale = mat.yScale = 0x10000;
	}
	if (readUBits(pos, end, 1)) {
		int nRotateBits = readUBits(pos, end, 5);
		mat.xRotate = readSBits(pos, end, nRotateBits);
		mat.yRotate = readSBits(pos, end, nRotateBits);
	} else {
		mat.xRotate = mat.yRotate = 0;
	}
	int nTranslateBits = readUBits(pos, end, 5);
	mat.xTranslate = readSBits(pos, end, nTranslateBits);
	mat.yTranslate = readSBits(pos, end, nTranslateBits);
}

// Skips a color transform with alpha, starting at the next byte boundary.
//
// The structure consists of 2 flag bits and a 4-bit field width, followed by
// 4 fields of that width for each of the flags that is set.
//
// That makes up to 2*4*15 = 120 bits to skip, which is more than readUBits()
// can handle in one go: it would have to shift a 32-bit value by 32 or more
// bits, which is undefined behaviour.  The bits are therefore skipped in
// portions of at most 32.
//
// The position is not byte-aligned afterwards.
// Throws OOB if the structure extends past end.
void
FlashInspector::skipColorTransformAlpha(Position& pos, DataIter const& end)
{
	byteAlign(pos);
	uint8_t const flags = readUBits(pos, end, 2);
	int const nbits = readUBits(pos, end, 4);
	
	int more_bits = 0;
	if (flags & 1) {
		more_bits += nbits << 2;
	}
	if (flags & 2) {
		more_bits += nbits << 2;
	}
	
	int const max_portion = std::numeric_limits<uint32_t>::digits;
	while (more_bits > 0) {
		int const portion = more_bits < max_portion ? more_bits : max_portion;
		readUBits(pos, end, portion);
		more_bits -= portion;
	}
}

// Reads a 16-bit unsigned integer stored with its low byte first
// and advances pos past it.
// Throws OOB if less than 2 bytes are available.
uint16_t
FlashInspector::readU16(DataIter& pos, DataIter const& end)
{
	uint16_t const low_byte = readU8(pos, end);
	uint16_t const high_byte = readU8(pos, end);
	return uint16_t(low_byte | (high_byte << 8));
}

// Reads a 32-bit unsigned integer stored with its lowest byte first
// and advances pos past it.
// Throws OOB if less than 4 bytes are available.
uint32_t
FlashInspector::readU32(DataIter& pos, DataIter const& end)
{
	uint32_t res = 0;
	for (int shift = 0; shift < 32; shift += 8) {
		// The byte is widened to an unsigned type before shifting.
		// Left to the integral promotions it would become a (signed) int,
		// and shifting a byte with its high bit set by 24 would overflow it.
		res |= uint32_t(readU8(pos, end)) << shift;
	}
	return res;
}

// Reads an nbits-wide unsigned bit field starting at pos and advances pos
// past it.  Within a byte, the bits are consumed starting with the most
// significant one.  The bits read first end up in the most significant
// bits of the result.
//
// Nothing is read and zero is returned if nbits is not positive.
// Throws OOB if the field extends past end.
uint32_t
FlashInspector::readUBits(Position& pos, DataIter const& end, int nbits)
{
	uint32_t result = 0;
	int bits_left = nbits;
	
	while (bits_left > 0) {
		if (pos.bytepos == end) {
			throw OOB();
		}
		
		// the part of the current byte that hasn't been consumed yet
		uint32_t const unread = uint8_t(*pos.bytepos) & (uint8_t(-1) >> pos.bitpos);
		bits_left -= 8 - pos.bitpos;
		
		if (bits_left < 0) {
			// The field ends inside this byte:
			// drop the bits that don't belong to it and stop.
			result |= unread >> -bits_left;
			pos.bitpos = bits_left + 8;
			return result;
		}
		
		// The rest of this byte belongs to the field.
		result |= unread << bits_left;
		pos.bitpos = 0;
		++pos.bytepos;
	}
	
	return result;
}

// Copies the next size bytes to dest and advances pos past them.
// Throws OOB if less than size bytes are available; the bytes that were
// available have already been copied and consumed in that case.
void
FlashInspector::readBytes(DataIter& pos, DataIter const& end, size_t size, uint8_t* dest)
{
	for (size_t i = 0; i < size; ++i) {
		dest[i] = readU8(pos, end);
	}
}

// Advances pos by size bytes.
// Throws OOB if less than size bytes are available; pos is left at end
// in that case.
void
FlashInspector::skipBytes(DataIter& pos, DataIter const& end, size_t size)
{
	size_t todo = size;
	while (todo != 0) {
		if (pos == end) {
			throw OOB();
		}
		++pos;
		--todo;
	}
}

// Advances pos past a zero-terminated string, including its terminator.
// Throws OOB if end is reached before the terminator is found.
void
FlashInspector::skipString(DataIter& pos, DataIter const& end)
{
	for (;;) {
		if (pos == end) {
			throw OOB();
		}
		char const current = *pos;
		++pos;
		if (current == '\0') {
			return;
		}
	}
}
