前言

看似可能是简单的需求,实际上只有简单的轮子,需要自己去组装,整个设计经过以下几个思路

1.按照写文件字节数

需要现在内存中遍历所有需要解压文件的文件头信息获得解压后的大小,然后计算得出总大小,存在痛点:
(1)除了zip包外,其余解压格式基本需要将文件全部读一遍(zip的文件信息是集中存储到一起的,而其他文件不是),这将导致非zip包解压在到进度条之前会有较长阻塞,为了展示进度条而展示,忽略本身解压性能,没必要
(2)冗余的多次遍历,违背一定的软件设计理念

2.尝试按照读和写字节数及时修正比率

首先,按照压缩文件的类型,按照习惯的压缩比率估算一个大概的解压大小,然后按照实时读取的字节数和写入的字节数,即时地修正这个比率,也存在痛点:
(1)设置修正的时间点难以把握,本质上读取部分读完了才会开始准备写入数据
(2)压缩包文件数量越少,越难把握整体进度
(3)由于每个文件压缩比率的不同,经常会出现进度回调的情况

3.尝试按照读文件字节数

可以将整个读取过程读取的字节数和总字节数作比,虽然读完字节数不代表整个文件已经落盘,但是读也是有buf的,不是全部文件读入内存才开始写,如果buf设置的足够的小的话,就可以忽略这个误差,最终将整个文件写完作为最终100%的信号,这样的模拟程度可以达到最高。

代码

package extract

import (
	"archive/tar"
	"context"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"slices"
	"strings"
	"time"

	"github.com/klauspost/compress/zip"

	"github.com/mholt/archiver/v3"
	"github.com/mholt/archives"
	"github.com/schollz/progressbar/v3"
)

// Extractor 解压器类
type Extractor struct {
	outputDir string
	handler   *archives.Extractor
	showName  string
}

// MonitoredReader 是一个包装了 io.Reader 的结构体,用于监控读取的字节数
type MonitoredReader struct {
	reader     io.Reader
	totalBytes int64
	startTime  time.Time
	bar        *progressbar.ProgressBar
	name       string // 用于标识这个 reader,方便日志输出
}

// Read 实现 io.Reader 接口
func (mr *MonitoredReader) Read(p []byte) (n int, err error) {
	n, err = mr.reader.Read(p)
	if err != nil {
		return n, err
	}
	mr.bar.Add64(int64(n))
	return n, err
}

func (mr *MonitoredReader) Seek(offset int64, whence int) (int64, error) {
	seeker, ok := mr.reader.(io.Seeker)
	if !ok {
		return 0, fmt.Errorf("reader does not support seeking")
	}
	return seeker.Seek(offset, whence)
}

func (mr *MonitoredReader) ReadAt(p []byte, off int64) (int, error) {
	readerAt, ok := mr.reader.(io.ReaderAt)
	if !ok {
		return 0, fmt.Errorf("reader does not support reading at offset")
	}
	n, err := readerAt.ReadAt(p, off)
	mr.bar.Add64(int64(n))
	return n, err
}

// NewExtractor 创建新的解压器实例
func NewExtractor(outputDir string, handler *archives.Extractor, showName string) *Extractor {
	return &Extractor{
		outputDir: outputDir,
		handler:   handler,
		showName:  showName,
	}
}

// 主解压函数
func ExtractArchive(ctx context.Context, extractName, archivePath, dest string) error {
	// 检查文件是否存在
	fileInfo, err := os.Stat(archivePath)
	if err != nil {
		return fmt.Errorf("无法访问文件 %s: %v", archivePath, err)
	}

	fmt.Printf("正在处理文件: %s (大小: %d 字节)\n", archivePath, fileInfo.Size())

	// 打开文件
	file, err := os.Open(archivePath)
	if err != nil {
		return fmt.Errorf("无法打开文件 %s: %v", archivePath, err)
	}
	defer file.Close()

	// 使用 archives 包识别文件格式
	format, stream, err := archives.Identify(ctx, archivePath, file)
	if err != nil {
		return fmt.Errorf("无法识别文件格式 %s: %v", archivePath, err)
	}
	// 判断是否是tar或者是zip类
	if strings.HasPrefix(format.Extension(), ".tar") || slices.Contains([]string{".zip", ".tgz"}, format.Extension()) {
		// 检查是否为解压器
		if handler, ok := format.(archives.Extractor); ok {
			// 某些格式(如 ZIP)需要 ReaderAt 和 Seeker,直接使用原始文件
			extract_handler := NewExtractor(dest, &handler, extractName)
			// 创建进度条
			bar := progressbar.NewOptions64(fileInfo.Size(),
				progressbar.OptionSetDescription(fmt.Sprintf("extract %s", extractName)),
				progressbar.OptionSetWriter(os.Stderr),
				progressbar.OptionShowBytes(true),
				progressbar.OptionSetWidth(50),
				progressbar.OptionThrottle(65*time.Millisecond),
				progressbar.OptionShowCount(),
				progressbar.OptionOnCompletion(func() {
					fmt.Fprint(os.Stderr, "\n")
				}),
				progressbar.OptionSpinnerType(14),
				progressbar.OptionFullWidth(),
				progressbar.OptionSetRenderBlankState(true),
				progressbar.OptionSetPredictTime(false),
			)
			monitoredReader := &MonitoredReader{
				reader:     stream,
				totalBytes: fileInfo.Size(),
				startTime:  time.Now(),
				name:       extractName,
				bar:        bar,
			}
			err := extract_handler.extractArchive(ctx, monitoredReader)
			bar.Finish()
			return err
		} else {
			return fmt.Errorf("无法识别文件格式 %s: %v", archivePath, err)
		}
	} else {
		// use archiver/v3 to extract
		err := archiver.Unarchive(archivePath, dest)
		if err != nil {
			return fmt.Errorf("解压失败: %v", err)
		}
		return nil
	}
}

// extractArchive 解压归档文件
func (e *Extractor) extractArchive(ctx context.Context, reader io.Reader) error {
	// 确保输出目录存在
	if err := os.MkdirAll(e.outputDir, 0755); err != nil {
		return fmt.Errorf("无法创建输出目录 %s: %v", e.outputDir, err)
	}

	// 定义文件处理函数
	f := func(ctx context.Context, f archives.FileInfo) error {
		// 检查上下文是否被取消
		if err := ctx.Err(); err != nil {
			return err
		}
		// 获取文件信息
		fileInfo := f.FileInfo

		var fullPath string
		var isSymlink bool
		var isHardlink bool
		var linkTarget string

		switch header := f.Header.(type) {
		case *tar.Header:
			fullPath = header.Name
			switch header.Typeflag {
			case tar.TypeSymlink:
				isSymlink = true
				linkTarget = header.Linkname
			case tar.TypeLink:
				isHardlink = true
				linkTarget = header.Linkname
			}
		case zip.FileHeader:
			fullPath = header.Name
			// ZIP 中符号链接通过文件模式标记,目标路径存储在文件内容
			if header.Mode()&os.ModeSymlink != 0 {
				isSymlink = true
				rc, err := f.Open()
				if err != nil {
					return fmt.Errorf("打开符号链接内容失败 %s: %v", fileInfo.Name(), err)
				}
				defer rc.Close()
				b, err := io.ReadAll(rc)
				if err != nil {
					return fmt.Errorf("读取符号链接目标失败 %s: %v", fileInfo.Name(), err)
				}
				linkTarget = string(b)
			}
		default:
			return fmt.Errorf("unknown header type: %T", header)
		}
		destPath := filepath.Join(e.outputDir, fullPath)

		// // 清理并验证文件路径,防止路径穿越攻击
		destPath = filepath.Clean(destPath)

		// 检查文件路径是否安全(防止 zip slip 攻击)
		if !strings.HasPrefix(destPath, filepath.Clean(e.outputDir)+string(os.PathSeparator)) &&
			destPath != filepath.Clean(e.outputDir) {
			return fmt.Errorf("unsafe file path: %s", fileInfo.Name())
		}

		// 如果是目录
		if fileInfo.IsDir() {
			// 创建目录
			err := os.MkdirAll(destPath, fileInfo.Mode())
			if err != nil {
				return fmt.Errorf("创建目录失败 %s: %v", destPath, err)
			}

			// 设置目录时间戳
			if modTime := fileInfo.ModTime(); !modTime.IsZero() {
				err = os.Chtimes(destPath, modTime, modTime)
				if err != nil {
					// 时间戳设置失败不是致命错误,仅记录警告
					fmt.Printf("警告: 无法设置目录时间戳 %s: %v\n", destPath, err)
				}
			}

			return nil
		}

		// 统一确保父目录存在
		parentDir := filepath.Dir(destPath)
		if parentDir != "." && parentDir != "/" {
			err := os.MkdirAll(parentDir, 0755)
			if err != nil {
				return fmt.Errorf("创建父目录失败 %s: %v", parentDir, err)
			}
		}

		// 处理符号链接(软链接)
		if isSymlink {
			// 如果目标路径已存在,先删除(遵循 archiver 的覆盖策略)
			if _, err := os.Lstat(destPath); err == nil {
				if err := os.Remove(destPath); err != nil {
					return fmt.Errorf("删除已存在路径失败 %s: %v", destPath, err)
				}
			}

			if err := os.Symlink(linkTarget, destPath); err != nil {
				return fmt.Errorf("创建符号链接失败 %s -> %s: %v", destPath, linkTarget, err)
			}

			// 不对符号链接调用 Chtimes,以避免更改目标文件时间
			return nil
		}

		// 处理硬链接(仅 tar)
		if isHardlink {
			// tar 的硬链接目标以归档根为基准
			targetAbs := filepath.Clean(filepath.Join(e.outputDir, linkTarget))

			// 如果目标路径已存在,先删除
			if _, err := os.Lstat(destPath); err == nil {
				if err := os.Remove(destPath); err != nil {
					return fmt.Errorf("删除已存在路径失败 %s: %v", destPath, err)
				}
			}

			if err := os.Link(targetAbs, destPath); err != nil {
				return fmt.Errorf("创建硬链接失败 %s -> %s: %v", destPath, targetAbs, err)
			}
			return nil
		}

		// 普通文件:打开源文件流
		srcReader, err := f.Open()
		if err != nil {
			return fmt.Errorf("打开压缩包中的文件失败 %s: %v", fileInfo.Name(), err)
		}
		defer srcReader.Close()

		// 创建目标文件
		destFile, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fileInfo.Mode())
		if err != nil {
			return fmt.Errorf("创建目标文件失败 %s: %v", destPath, err)
		}
		defer destFile.Close()

		// 复制文件内容
		_, err = io.Copy(destFile, srcReader)
		if err != nil {
			return fmt.Errorf("复制文件内容失败 %s: %v", destPath, err)
		}

		// 设置文件时间戳
		if modTime := fileInfo.ModTime(); !modTime.IsZero() {
			err = os.Chtimes(destPath, modTime, modTime)
			if err != nil {
				// 时间戳设置失败不是致命错误,仅记录警告
				fmt.Printf("警告: 无法设置文件时间戳 %s: %v\n", destPath, err)
			}
		}

		return nil
	}
	handler := *e.handler
	return handler.Extract(ctx, reader, f)
}