cos.jar 의 MultipartRequest를 확장한 확장자 체크 만들기.

Published on: 2011. 2. 15. 09:36 by louis.dev

그동안 파일 업로드를 할때 apache의 commons-fileupload를 사용하거나, oreilly의 MultipartRequest를 통해 업로드를 해왔습니다. 사실은 commons-fileupload 보다 MultipartRequest를 사용하는 것을 더 선호 했습니다.

왜냐!!

더 쉽기 때문이죠..^^;;

MultipartRequest를 통해 파일 업로드를 하려면, 단순히 MultipartRequest 인스턴스를 생성하는 것만으로도 파일업로드가 진행이 되었습니다.

이렇게 쉽게 파일 업로드를 하는 MultipartRequest에도 문제점은 가지고 있었으니.. 그것은 바로, 파일업로드를 진행하기 전, 파일의 확장자를 체크하지 못한다는 점입니다. 보안상 이슈로 인해 jsp, php, asp와 같은 파일들은 업로드 하지 못하게 막아야 하는데 ( 업로드 한후 업로드 한 파일을 실행하여 서버의 정보나 기타 다른 보안내용을 가져갈 수 있기 때문입니다..). 물론 스크립트 단에서 스크립트를 통해 확장자를 제어 하는 방법이 존재 하긴 하지만, 스크립트는 언제나 우회 가능 하기 때문에, 결국은 Backend에서 확장자를 체크하고, 확장자가 유효한지, 유효하지 않은지에 따라 업로드 가능 유무를 체크 해야 합니다.

이런 점을 들어, MultipartRequest를 개조하여 CheckExtentionMultipartRequest 를 만들어 보았습니다. 파일 Validation을 하는 것과 동시에, 소스 코드 리펙토링을 진행하여 코드를 작성해 보았습니다.

소스코드를 보다 보니 대부분의 Collection을 Vector로 작성을 하였는데, 아마 synchronized 때문에 Vector를 사용한 것으로 보여집니다. 하지만 자바 권고 안은 synchronized 가 걸린 Collection은
 
 Collections.synchronizedMap( new HashMap() )

을 통해 생성하라고 권고 하고 있기 때문에 Vector로 구현된 내용을 위와 같이 수정하였습니다.


CheckExtentionMultipartRequest

package net.tutorial;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Vector;

import javax.servlet.http.HttpServletRequest;

import net.daum.info.exception.InvalidFileExtensionException;

import com.oreilly.servlet.multipart.FilePart;
import com.oreilly.servlet.multipart.FileRenamePolicy;
import com.oreilly.servlet.multipart.MultipartParser;
import com.oreilly.servlet.multipart.ParamPart;
import com.oreilly.servlet.multipart.Part;

@SuppressWarnings("unchecked")
public class CheckExtentionMultipartRequest {
	//파라미터 저장 Map
	protected Map<Object,Object> parameters = Collections.synchronizedMap( new HashMap() ); 
	//file 저장 Map
	protected Map<Object,Object> files = Collections.synchronizedMap( new HashMap() ); 

	//기존의 파라미터들은 동일하고, String[] extentsion이 추가 됨.
	//.(컴마)를 제외한 제외하려는 확장자를 array 형태로 넘기면 됨(ex - {"php","jsp","asp"} )
	public CheckExtentionMultipartRequest(HttpServletRequest request,String saveDirectory, int maxPostSize, String encoding,FileRenamePolicy policy, String[] extensions) throws Exception {
		// Sanity check values
		if (request == null)
			throw new IllegalArgumentException("request cannot be null");
		if (saveDirectory == null)
			throw new IllegalArgumentException("saveDirectory cannot be null");
		if (maxPostSize <= 0) {
			throw new IllegalArgumentException("maxPostSize must be positive");
		}

		// Save the dir
		File dir = new File(saveDirectory);

		// 디렉토리 인지 체크
		if (!dir.isDirectory())
			throw new IllegalArgumentException("Not a directory: "+ saveDirectory);

		// write 할수 있는 디렉토리인지 체크
		if (!dir.canWrite())
			throw new IllegalArgumentException("Not writable: " + saveDirectory);

		MultipartParser parser = new MultipartParser(request, maxPostSize, true, true, encoding);

		if (request.getQueryString() != null) {
			// HttpUtil이 Deprecated 되었기 때문에 다른방식으로 작성						Map<String,Object> queryParameters = Collections.synchronizedMap( request.getParameterMap() );
			Iterator queryParameterNames = queryParameters.keySet().iterator();
			while( queryParameterNames.hasNext() ) {
				Object paramName = queryParameterNames.next();
				String[] values = ( String[] ) queryParameters.get( paramName );
				List<Object> newValues = Collections.synchronizedList( new ArrayList<Object>() );
				for( String value : values ) {
					newValues.add( value );
				}
				parameters.put(paramName, newValues ); 	
			}
		}

		Part part;
		while ((part = parser.readNextPart()) != null) {
			String name = part.getName();
			if (name == null) {
				throw new IOException("Malformed input: parameter name missing (known Opera 7 bug)");
			}
			if (part.isParam()) {
				// It's a parameter part, add it to the vector of values
				ParamPart paramPart = (ParamPart) part;
				String value = paramPart.getStringValue();
				List<Object> existingValues = (List<Object>) parameters.get(name);
				if (existingValues == null) {
					existingValues = new Vector();
					parameters.put(name, existingValues);
				}
				existingValues.add( value );
			} else if ( part.isFile() ) {
				// It's a file part
				FilePart filePart = (FilePart) part;
				String fileName = filePart.getFileName();
				if (fileName != null) {
					//file 확장자를 validation 합니다.
					//만약 유효하지 않은 확장자라면 InvalidFileExtensionException을 throw 합니다.
					if( !isValidFileExtension(fileName, extensions)) {
						throw new InvalidFileExtensionException( "Invalid File Extension" );
					}
					else {
						filePart.setRenamePolicy(policy); // null policy is OK
						filePart.writeTo(dir);
						files.put(name,
								new UploadedFile(dir.toString(), filePart
										.getFileName(), fileName, filePart
										.getContentType()));	
					}
					
				} else {
					// The field did not contain a file
					files.put(name, new UploadedFile(null, null, null, null));
				}
			}
		}
	}

	public Iterator<Object> getParameterNames() {
		return parameters.keySet().iterator();
	}

	public Iterator<Object> getFileNames() {
		return files.keySet().iterator();
	}

	public String getParameter(String name) {
		try {
			List<Object> values = (List<Object>) parameters.get(name);
			if (values == null || values.size() == 0) {
				return null;
			}
			String value = (String) values.get(values.size() - 1);
			return value;
		} catch (Exception e) {
			return null;
		}
	}

	public Object[] getParameterValues(String name) {
		try {
			List<Object> values = (List<Object>) parameters.get(name);
			if (values == null || values.size() == 0) {
				return null;
			}
			return values.toArray();
		} catch (Exception e) {
			return null;
		}
	}

	public String getFilesystemName(String name) {
		try {
			UploadedFile file = (UploadedFile) files.get(name);
			return file.getFilesystemName(); // may be null
		} catch (Exception e) {
			return null;
		}
	}

	public String getOriginalFileName(String name) {
		try {
			UploadedFile file = (UploadedFile) files.get(name);
			return file.getOriginalFileName(); // may be null
		} catch (Exception e) {
			return null;
		}
	}

	public String getContentType(String name) {
		try {
			UploadedFile file = (UploadedFile) files.get(name);
			return file.getContentType(); // may be null
		} catch (Exception e) {
			return null;
		}
	}

	public File getFile(String name) {
		try {
			UploadedFile file = (UploadedFile) files.get(name);
			return file.getFile(); // may be null
		} catch (Exception e) {
			return null;
		}
	}
	
	public boolean isValidFileExtension( String fileName ,String[] extensions) throws Exception{
		boolean result = false;
		if( fileName != null && !"".equals( fileName ) && !isContainFileExtension(fileName, extensions)) {
			result = true;
		}
		else {
			result = false;
		}
		return result;
	}
	private boolean isContainFileExtension( String fileName, String[] extensions ) throws Exception{
		boolean result = false;
		String fileExtension = getFileExtension( fileName );
		for( String ex : extensions ) {
			if( fileExtension.equals( ex ) ) {
				result = true;
				break;
			}
		}
		return result;
	}
	private String getFileExtension( String fileName ) throws Exception{
		String fileExtension = "";
		if( fileName != null && !"".equals( fileName )) {
			if( fileName.lastIndexOf( "." ) != -1){
				fileExtension = fileName.toLowerCase().substring( fileName.lastIndexOf( "." ) + 1, fileName.length() );
			}
			else {
				fileExtension = "";
			}
		}else{
			fileExtension = "";
		}
		return fileExtension;
	}
}
class UploadedFile {

	private String dir;
	private String filename;
	private String original;
	private String type;

	UploadedFile(String dir, String filename, String original, String type) {
		this.dir = dir;
		this.filename = filename;
		this.original = original;
		this.type = type;
	}

	public String getContentType() {
		return type;
	}

	public String getFilesystemName() {
		return filename;
	}

	public String getOriginalFileName() {
		return original;
	}

	public File getFile() {
		if (dir == null || filename == null) {
			return null;
		} else {
			return new File(dir + File.separator + filename);
		}
	}
}


위와 같이 작성하면 됩니다. InvalidFileExtensionException은 간단하게 다음과 같이 작성하시면 됩니다.

package net.tutorial.exception

@SuppressWarnings("serial")
public class InvalidFileExtensionException extends Exception {

	public InvalidFileExtensionException(String message) {
		super(message);
	}

}


실질 적으로 사용하는 방법은 다음과 같이 사용하시면 됩니다.
ExtensionCheckMultipartRequest multipartReqeust = null;
	try {
		String ex = {"php","jsp","asp"};
		multipartReqeust = new ExtensionCheckMultipartRequest(request, "파일업로드 패스", 5 * 1024 * 1024, "utf-8", new DefaultFileRenamePolicy(), ex);
	}catch ( InvalidFileExtensionException e) {
		e.printStackTrace();
		//TODO
	}