#!/usr/bin/python

# Copyright 2007 by Tobia Conforto <tobia.conforto@gmail.com>
#
# This program is free software; you can redistribute it and/or modify it under the terms of the GNU General
# Public License as published by the Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
# for more details.
#
# You should have received a copy of the GNU General Public License along with this program.
# If not, see http://www.gnu.org/licenses/

# Versions: 0.1 2007-08-13 Initial release
#           0.2 2008-05-12 Small fixes for Zen Xtra models

from __future__ import division
import sys, os, codecs, array, time, operator, getopt
import LRU

class CFS:
	'''Read-only, cluster-oriented access to a CFS filesystem in a disk or image.

	Cluster -1 starts at `offset` bytes into the image, so cluster n starts
	at offset + (n + 1) * clusterSize.'''

	clusterSize = 8192
	cacheMem = 10 * 2**20 # keep 10MB of recently read clusters in ram

	def __init__(self, filename, offset = 0):
		'''Filename and optional offset where the CFS filesystem begins
		(offset of cluster -1, the one filled with 0xff)'''
		# binary mode: text mode may translate bytes of the raw image on some platforms
		self.image = open(filename, 'rb')
		self.offset = offset
		self.clusterCache = LRU.LRU(self.cacheMem // self.clusterSize)

	def _cluster_position(self, cluster):
		'''Byte position in the image where the given cluster starts.'''
		return self.offset + (cluster + 1) * self.clusterSize

	def __getitem__(self, key):
		'''Get the nth CFS cluster from the image and cache it for later usage.
		Accepts simple slices of clusters (a missing start means 0, the step
		is ignored), but doesn't process negative indices.
		In any case it returns the requested data as a byte string; clusters
		beyond the end of the image come back short or empty.'''
		if isinstance(key, slice):
			cstart, cstop = key.start, key.stop
			if cstart is None:
				cstart = 0
		else:
			cstart, cstop = key, key + 1
		chunks = []
		for cluster in range(cstart, cstop):
			if cluster in self.clusterCache:
				chunk = self.clusterCache[cluster]
			else:
				self.image.seek(self._cluster_position(cluster))
				chunk = self.image.read(self.clusterSize)
				self.clusterCache[cluster] = chunk
			chunks.append(chunk)
		# join once instead of repeated concatenation
		return ''.join(chunks)

	def get_byteswapped_data(self, cluster):
		'''Get the nth CFS cluster from the image, without caching it.
		Swap the position of every two bytes and return it as an array object.
		This method is designed for bulk file retrieving.
		Raises EOFError if the image ends before the cluster does.'''
		a = array.array('H')
		self.image.seek(self._cluster_position(cluster))
		a.fromfile(self.image, self.clusterSize // 2)
		a.byteswap()
		return a

	def inode(self, cluster):
		'''Parse and return the inode stored in the given cluster.'''
		return CFSInode(self, cluster)

def pdp_uint32(data, offset = 0):
	'''Decode the unsigned 32-bit integer at `offset` stored in PDP-11
	("middle-endian") order: two little-endian 16-bit words, most
	significant word first. Raises ValueError if fewer than 4 bytes remain.'''
	b0, b1, b2, b3 = [ord(ch) for ch in data[offset : offset + 4]]
	high_word = (b1 << 8) | b0
	low_word = (b3 << 8) | b2
	return (high_word << 16) | low_word

def pdp_uint16(data, offset = 0):
	'''Decode the little-endian unsigned 16-bit integer at `offset`.
	Raises ValueError if fewer than 2 bytes remain.'''
	low, high = [ord(ch) for ch in data[offset : offset + 2]]
	return low | (high << 8)

def ucs2string(data, offset, length):
	'''Decode `length` bytes (not characters) starting at `offset` as
	UTF-16LE and return the text. The non-final stateless decoder is used,
	so an incomplete trailing byte is dropped rather than raising.'''
	raw = data[offset : offset + length]
	text, _consumed = codecs.utf_16_le_decode(raw)
	return text

def pdp_getbit(bitmap, bit_no):
	'''Return bit `bit_no` (0 or 1) of a bitmap stored as consecutive
	PDP-endian 32-bit words; bit 0 is the least significant bit of the
	first word.'''
	word_index, bit_index = divmod(bit_no, 32)
	word = pdp_uint32(bitmap, word_index * 4)
	return (word >> bit_index) & 1

class CFSInode:
	'''A CFS inode, parsed from a single cluster.

	The constructor sets cluster, cfs, serial, metadata (tag -> raw value
	bytes), dataclusters (flat ordered list of data cluster numbers) and,
	for inodes without metadata (treated as directories), direntries (list
	of CFSDirEntry). filename, path and filesize come from the metadata
	when present; otherwise the class defaults below apply.'''

	# marker for "no cluster" / end of list in cluster pointer tables
	NO_CLUSTER = 0xFFFFFFFF

	filename = '(no filename)'
	filesize = 0
	path = () # immutable default, so no list object is shared between instances

	def __init__(self, cfs, cluster):
		'''Parse the inode stored in the given cluster of `cfs`.
		Raises AssertionError if the data fails the consistency checks.'''
		self.cluster = cluster
		self.cfs = cfs
		inode = cfs[cluster]
		# reading misc flags and values
		assert pdp_uint32(inode[4:8]) == cluster # self-reference
		self.serial = pdp_uint32(inode, 0x78)
		self._read_metadata(inode)
		self._collect_dataclusters(inode)
		# reading directory entries
		if not self.metadata: # any better way of telling dirs and files apart?
			self._read_direntries()

	def _read_metadata(self, inode):
		'''Decode the tagged metadata records (count at 0x7c, records from 0x80)
		into self.metadata; tags '07', '0=' and '0>' also set filename, path
		and filesize.'''
		count_metadata = pdp_uint32(inode, 0x7c)
		offset = 0x80
		self.metadata = {}
		for i in range(count_metadata):
			assert pdp_uint16(inode, offset) == 3
			length = pdp_uint16(inode, offset + 2)
			tag = ucs2string(inode, offset + 4, 4)
			value = offset + 10 # the record's value starts 10 bytes into it
			self.metadata[tag] = inode[value : value + length]
			if tag == '07':
				# UCS-2 name; the last 2 bytes (presumably a NUL terminator) are dropped
				self.filename = ucs2string(inode, value, length - 2)
			elif tag == '0=':
				# backslash-separated directory path
				self.path = ucs2string(inode, value, length - 2).strip('\\').split('\\')
			elif tag == '0>':
				self.filesize = pdp_uint32(inode, value)
			offset = value + length

	def _collect_dataclusters(self, inode):
		'''Build self.dataclusters, in order: the direct pointers at 0x20-0x4c,
		then the clusters listed in the pointer cluster named at 0x58, then
		those listed in each pointer cluster named by the cluster at 0x64.'''
		self.dataclusters = []
		for off in range(0x20, 0x4c + 1, 4):
			c = pdp_uint32(inode, off)
			if c != self.NO_CLUSTER:
				self.dataclusters.append(c)
		pointerclusters = []
		second_class_chain = pdp_uint32(inode, 0x58)
		if second_class_chain != self.NO_CLUSTER:
			pointerclusters.append(second_class_chain)
		third_class_chain = pdp_uint32(inode, 0x64)
		if third_class_chain != self.NO_CLUSTER:
			pointerclusters.extend(self._read_pointers(third_class_chain))
		for pnt in pointerclusters:
			self.dataclusters.extend(self._read_pointers(pnt))

	def _read_pointers(self, cluster):
		'''Return the cluster numbers listed in pointer cluster `cluster`, up to
		(not including) the first NO_CLUSTER marker.'''
		# fetch the cluster once, rather than re-indexing cfs for every 4-byte entry
		data = self.cfs[cluster]
		pointers = []
		for off in range(0, len(data) - 3, 4):
			c = pdp_uint32(data, off)
			if c == self.NO_CLUSTER:
				break
			pointers.append(c)
		return pointers

	def _read_direntries(self):
		'''Read the entries of this directory into self.direntries.
		The directory contents are scanned in 0x10000-byte blocks (one per 8
		data clusters). Each block has a 204-byte bitmap at offset 16 and 1632
		(= 204 * 8) 40-byte entry slots from offset 220; a set bit marks a
		used slot.'''
		# total number of entries, stored at byte 8 of the directory contents
		count_direntries = pdp_uint32(self, 8)
		self.direntries = []
		assert len(self.dataclusters) % 8 == 0
		for block_no in range(len(self.dataclusters) // 8):
			start = block_no * 0x10000
			block = self[start : start + 0x10000]
			bitmap = block[16 : 16 + 204]
			for n in range(1632):
				if pdp_getbit(bitmap, n):
					off = 220 + n * 40
					self.direntries.append(CFSDirEntry(self.cfs, block[off : off + 40]))
		assert len(self.direntries) == count_direntries

	def __getitem__(self, key):
		'''Returns the given byte (or byte slice) from the file contents, as a
		byte string. A missing slice start means 0; the step is ignored and
		negative indices are not supported. Reads past the last data cluster
		are truncated.'''
		if isinstance(key, slice):
			bstart, bstop = key.start, key.stop
			if bstart is None:
				bstart = 0
		else:
			bstart, bstop = key, key + 1
		cs = self.cfs.clusterSize
		cstart = bstart // cs
		cstop = (bstop - 1) // cs + 1
		data = ''.join([ self.cfs[x] for x in self.dataclusters[cstart : cstop] ])
		return data[bstart - cs * cstart : bstop - cs * cstart]

class CFSDirEntry:
	'''One 40-byte directory entry slot: the cluster number of the entry's
	inode, the full filename length and up to 15 leading filename chars.'''

	def __init__(self, cfs, entrydata):
		# `cfs` is accepted but not used here
		self.cluster = pdp_uint32(entrydata, 0) # cluster no. of the inode
		# length of the full filename, in UCS-2 characters
		self.len_filename = pdp_uint16(entrydata, 4)
		# only the first 15 chars (30 bytes, from byte 8) are stored in the slot
		name_bytes = min(30, 2 * self.len_filename)
		self.shortname = ucs2string(entrydata, 8, name_bytes)

if __name__ == '__main__':

	# commandline arguments
	optlist, args = getopt.gnu_getopt(sys.argv[1:], 'o:')
	opts = dict(optlist)
	offset = int(opts.get('-o', 20 * 2**20))

	if len(args) != 3:
		print 'Usage: zenrecover.py [-o OFFSET] DISK_OR_IMAGE SECTION OUTPUT_DIR'
		print 'DISK_OR_IMAGE is the disk containing the filesystem, or an image thereof'
		print 'OFFSET is the offset at which the filesystem starts (in bytes, default 20M)'
		print 'SECTION is the section of the filesystem to recover: "archives" or "songs"'
		print 'OUTPUT_DIR is the directory in which to place the recovered files'
		sys.exit(1)

	cfs = CFS(args[0], offset)
	section = args[1]
	outdir = args[2]

	# find the root inode
	rootinode = None
	for c in range(4, 0x10000):
		if pdp_uint32(cfs[c][:4]) == 0x3bbe0ad9:
			i = cfs.inode(c)
			if i.serial != 0xFFFFFFFFL:
				print "Found inode at cluster 0x%x, but serial number is not -1" % c
				continue
			rootinode = i
			break
	if not rootinode:
		raise "Could not find the root inode"

	# find the root directories
	root = {}
	for entry in rootinode.direntries:
		root[entry.shortname] = entry.cluster

	# begin recovery
	dirinode = cfs.inode(root[section])
	os.makedirs(outdir)
	lastfiles = [(1,1)] # timing of latest few files recovered (size in bytes, time in secs)
	t = len(dirinode.direntries)
	for i, entry in enumerate(dirinode.direntries):
		if entry.shortname != '.':
			t0 = time.time()
			inode = cfs.inode(entry.cluster)
			print '\r%d%% %.1fMB/s "%s" (%.1fMB)\033[K' % (
					i * 100 // t,
					operator.truediv(*map(sum, zip(*lastfiles))) / 2**20,
					inode.filename[:50],
					inode.filesize / 2**20),
			sys.stdout.flush()
			path = os.path.join(outdir, *inode.path)
			try:
				os.makedirs(path)
			except:
				pass
			f = file(os.path.join(path, inode.filename), 'w')
			remaining = inode.filesize
			for c in inode.dataclusters:
				if remaining >= cfs.clusterSize:
					cfs.get_byteswapped_data(c).tofile(f)
				else:
					f.write(cfs.get_byteswapped_data(c).tostring()[:remaining])
				remaining -= min(cfs.clusterSize, remaining)
			f.close()
			assert remaining == 0
			if len(lastfiles) >= 32: #transfer speed is calculated on latest 32 files
				lastfiles.pop(0)
			lastfiles.append((inode.filesize, time.time() - t0))
	print '\rDone.\033[K'
